import java.lang.classfile.BootstrapMethodEntry;
import java.lang.classfile.constantpool.*;
import java.lang.constant.*;
import java.lang.constant.DirectMethodHandleDesc.Kind;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.StringConcatFactory;
import java.lang.invoke.TypeDescriptor;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.SequencedSet;
import java.util.function.Function;

import static java.lang.constant.ConstantDescs.*;

/**
 * Checks that constant pool entries produced by {@link ConstantPoolBuilder} honor the
 * {@code equals}/{@code hashCode} contract when the "same" entry is created in two independent
 * pools.
 *
 * <p>One pool is padded with unrelated entries first, so the two pools assign different indices
 * to otherwise identical entries. For each entry kind, if the two entries compare equal but report
 * different hash codes, the entry type is recorded. Any recorded types cause an
 * {@link AssertionError} at the end of {@link #main}.
 */
public final class PoolEntryEqHashBug {
    /**
     * Runs the equals/hashCode consistency check for every constant pool entry kind.
     *
     * @param args the command line arguments (ignored)
     * @throws AssertionError if any entry type produced equal entries with unequal hash codes; the
     *     error message is the set of offending entry types, in the order they were tested
     */
    public static final void main(final String... args) {
        var pool1 = ConstantPoolBuilder.of();
        var pool2 = ConstantPoolBuilder.of();

        // Ensure that pool1 and pool2 assign different indices to the same logical entries, so
        // that any index dependence in equals/hashCode becomes observable.
        prefillWithGarbage(pool1);

        // Insertion-ordered so the failure message lists types in test order.
        SequencedSet<Class<?>> badPoolEntries
                = new LinkedHashSet<>();

        // Root entries (no references to other pool entries)
        testPoolEntry(badPoolEntries, Utf8Entry.class, pool1, pool2, pool -> pool.utf8Entry("Test Utf8Entry"));
        testPoolEntry(badPoolEntries, IntegerEntry.class, pool1, pool2, pool -> pool.intEntry(12345));
        testPoolEntry(badPoolEntries, FloatEntry.class, pool1, pool2, pool -> pool.floatEntry(12345f));
        testPoolEntry(badPoolEntries, LongEntry.class, pool1, pool2, pool -> pool.longEntry(12345L));
        testPoolEntry(badPoolEntries, DoubleEntry.class, pool1, pool2, pool -> pool.doubleEntry(12345d));

        // Ref entries (refer to other pool entries)
        testPoolEntry(badPoolEntries, ClassEntry.class, pool1, pool2, pool -> pool.classEntry(CD_Object));
        testPoolEntry(badPoolEntries, StringEntry.class, pool1, pool2, pool -> pool.stringEntry("Test String"));

        testPoolEntry(badPoolEntries, FieldRefEntry.class, pool1, pool2, pool -> pool.fieldRefEntry(CD_Boolean, "TRUE", CD_Boolean));
        testPoolEntry(badPoolEntries, MethodRefEntry.class, pool1, pool2, pool -> pool.methodRefEntry(CD_Exception, INIT_NAME, MTD_void));
        testPoolEntry(badPoolEntries, InterfaceMethodRefEntry.class, pool1, pool2, pool -> pool.interfaceMethodRefEntry(CD_Collection, "isEmpty", MethodTypeDesc.of(CD_boolean)));

        testPoolEntry(badPoolEntries, NameAndTypeEntry.class, pool1, pool2, pool -> pool.nameAndTypeEntry("foo", MethodTypeDesc.of(CD_Object)));
        testPoolEntry(badPoolEntries, MethodHandleEntry.class, pool1, pool2, pool -> pool.methodHandleEntry(BSM_INVOKE));
        testPoolEntry(badPoolEntries, MethodTypeEntry.class, pool1, pool2, pool -> pool.methodTypeEntry(MethodTypeDesc.of(CD_String)));

        testPoolEntry(badPoolEntries, ConstantDynamicEntry.class, pool1, pool2, pool -> pool.constantDynamicEntry(FALSE));
        testPoolEntry(badPoolEntries, InvokeDynamicEntry.class, pool1, pool2, pool -> pool.invokeDynamicEntry(
                DynamicCallSiteDesc.of(
                        ConstantDescs.ofCallsiteBootstrap(
                                ClassDesc.of(StringConcatFactory.class.getName()),
                                "makeConcat",
                                CD_CallSite
                        ),
                        MethodTypeDesc.of(CD_String, CD_Object, CD_Object, CD_Object)
                )
        ));

        testPoolEntry(badPoolEntries, ModuleEntry.class, pool1, pool2, pool -> pool.moduleEntry(ModuleDesc.of("java.base")));
        testPoolEntry(badPoolEntries, PackageEntry.class, pool1, pool2, pool -> pool.packageEntry(PackageDesc.ofInternalName("java/lang")));

        // Bootstrap method table entries (not constant pool entries, but built by the same builder)
        testPoolEntry(badPoolEntries, BootstrapMethodEntry.class, pool1, pool2, pool -> pool.bsmEntry(
                ConstantDescs.ofCallsiteBootstrap(
                        ClassDesc.of(LambdaMetafactory.class.getName()),
                        "metafactory",
                        CD_CallSite,
                        CD_MethodType,
                        CD_MethodHandle,
                        CD_MethodType
                ),
                List.of(
                        MethodTypeDesc.of(CD_Object, CD_Object),
                        MethodHandleDesc.ofMethod(Kind.VIRTUAL, CD_Object, "hashCode", MethodTypeDesc.of(CD_int)),
                        MethodTypeDesc.of(CD_int, CD_Object)
                )
        ));

        if (!badPoolEntries.isEmpty()) {
            throw new AssertionError(badPoolEntries);
        }
    }

    /**
     * Creates the same logical entry in both pools and records {@code type} in
     * {@code badPoolEntries} if the two entries are equal but have different hash codes.
     *
     * <p>Equality is checked in both directions, so an asymmetric {@code equals} cannot hide a
     * violation of the "equal objects must have equal hash codes" contract.
     *
     * @param <T> the entry type under test
     * @param badPoolEntries accumulator for entry types that violate the contract
     * @param type the entry type; used to check the factory result and as the recorded key
     * @param pool1 the first pool
     * @param pool2 the second pool (callers are expected to make its indices differ from
     *     {@code pool1})
     * @param factory creates the entry under test in the given pool
     */
    private static final <T> void testPoolEntry(
            final SequencedSet<Class<?>> badPoolEntries,
            final Class<T> type,
            final ConstantPoolBuilder pool1,
            final ConstantPoolBuilder pool2,
            final Function<? super ConstantPoolBuilder, ? extends T> factory
    ) {
        final T entry1 = type.cast(factory.apply(pool1));
        final T entry2 = type.cast(factory.apply(pool2));

        final boolean equal = entry1.equals(entry2) || entry2.equals(entry1);
        if (equal && entry1.hashCode() != entry2.hashCode()) {
            badPoolEntries.add(type);
        }
    }

    /**
     * Adds unrelated entries to {@code pool} so that entries created afterwards are placed at
     * different constant pool (and bootstrap method table) positions than in a fresh pool.
     *
     * @param pool the pool to pad
     */
    private static final void prefillWithGarbage(final ConstantPoolBuilder pool) {
        for (int i = 0; i < 10; i += 1) {
            pool.utf8Entry("ignore: " + i);
        }

        // Also occupy a bootstrap method slot, so BootstrapMethodEntry positions are shifted too.
        pool.bsmEntry(
                MethodHandleDesc.ofMethod(
                        Kind.STATIC,
                        ClassDesc.of(DEFAULT_NAME),
                        DEFAULT_NAME,
                        MethodTypeDesc.of(CD_Object, CD_MethodHandles_Lookup, CD_String, ClassDesc.of(TypeDescriptor.class.getName()))
                ),
                List.of()
        );
    }
}