package org.openjdk.bench.java.lang.foreign;

import java.lang.foreign.*;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.lang.invoke.MethodHandle;
import java.util.concurrent.TimeUnit;
import java.util.function.IntFunction;

/**
 * Measures the cost of a downcall to the native {@code identity} function (descriptor
 * {@code (JAVA_INT) -> JAVA_INT}) from the {@code CallOverhead} library, comparing how the
 * downcall {@link MethodHandle} is held by the caller:
 * <ul>
 *   <li>{@link #identity_jextract_stable()} - handle produced by {@code Bindings.downcallFor}
 *       through a {@code StableValue.ofIntFunction} that is itself wrapped in a
 *       {@code StableValue} field;</li>
 *   <li>{@link #identity_jextract_final()} - handle produced the same way, but the int-function
 *       is held in a plain {@code final} instance field;</li>
 *   <li>{@link #identity_ffm()} - handle created once in the static initializer and held in a
 *       {@code static final} field.</li>
 * </ul>
 * The {@code jextract} benchmark names presumably refer to the shape of jextract-generated
 * bindings (confirm against the benchmark's intent).
 */
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@State(org.openjdk.jmh.annotations.Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(value = 3, jvmArgs = { "--enable-preview", "--enable-native-access=ALL-UNNAMED", "-Djava.library.path=micro/native" })
public class TestIdentity {

    /**
     * Binding holder that exposes the native {@code identity} function through two different
     * holders of the same lazily-computed int-function, so the benchmarks can compare them.
     */
    static class Bindings {

        static final Linker LINKER = Linker.nativeLinker();

        /** Number of downcall slots handled by {@link #downcallFor(int)} (slot 0 = identity). */
        private static final int DOWNCALL_COUNT = 1;

        // Both fields wrap the same factory; they differ only in how the int-function is held.
        final StableValue<IntFunction<MethodHandle>> downcallsStable =
                StableValue.of(StableValue.ofIntFunction(DOWNCALL_COUNT, this::downcallFor));
        final IntFunction<MethodHandle> downcallsFinal =
                StableValue.ofIntFunction(DOWNCALL_COUNT, this::downcallFor);
        final SymbolLookup lookup;

        /**
         * @param lookup symbol lookup used by {@link #downcallFor(int)} to resolve native symbols
         */
        public Bindings(SymbolLookup lookup) {
            this.lookup = lookup;
        }

        /**
         * Calls native {@code identity} via the handle reached through {@link #downcallsStable}.
         *
         * @param arg value passed to the native function
         * @return the native function's result
         * @throws IllegalStateException wrapping anything thrown by the lookup or the call
         */
        int identityStable(int arg) {
            try {
                // The (int) cast is required so the invokeExact call site has type (int)int.
                return (int)downcallsStable.orElseThrow().apply(0).invokeExact(arg);
            } catch (Throwable ex) {
                // invokeExact is declared to throw Throwable; rethrow unchecked.
                throw new IllegalStateException(ex);
            }
        }

        /**
         * Calls native {@code identity} via the handle reached through {@link #downcallsFinal}.
         *
         * @param arg value passed to the native function
         * @return the native function's result
         * @throws IllegalStateException wrapping anything thrown by the lookup or the call
         */
        int identityFinal(int arg) {
            try {
                // The (int) cast is required so the invokeExact call site has type (int)int.
                return (int)downcallsFinal.apply(0).invokeExact(arg);
            } catch (Throwable ex) {
                // invokeExact is declared to throw Throwable; rethrow unchecked.
                throw new IllegalStateException(ex);
            }
        }

        /**
         * Creates the downcall handle for slot {@code i}.
         * If the symbol cannot be found, this propagates whatever {@code lookup.findOrThrow} throws.
         *
         * @param i slot index; only {@code 0} (native {@code identity}) is defined
         * @return a downcall handle for descriptor {@code (JAVA_INT) -> JAVA_INT}
         * @throws IllegalArgumentException if {@code i} is not a known slot
         */
        MethodHandle downcallFor(int i) {
            return switch (i) {
                case 0 -> {
                    MemorySegment address = lookup.findOrThrow("identity");
                    FunctionDescriptor descriptor = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
                    yield LINKER.downcallHandle(address, descriptor);
                }
                default -> throw new IllegalArgumentException("Unknown downcall index: " + i);
            };
        }

        // Declared after LINKER so that static field is already initialized when this runs.
        static final Bindings LIB = new Bindings(SymbolLookup.loaderLookup());
    }

    /** Downcall through the {@code StableValue}-wrapped int-function. */
    @Benchmark
    public int identity_jextract_stable() {
        return Bindings.LIB.identityStable(42);
    }

    /** Downcall through the int-function held in a {@code final} field. */
    @Benchmark
    public int identity_jextract_final() {
        return Bindings.LIB.identityFinal(42);
    }

    /** Downcall through the eagerly created {@code static final} handle. */
    @Benchmark
    public int identity_ffm() throws Throwable {
        return (int) IDENTITY_MH.invokeExact(42);
    }

    /** Eagerly created downcall handle for native {@code identity}. */
    static final MethodHandle IDENTITY_MH;

    static {
        // Load the library first so the loader lookup below (and in Bindings) can resolve it.
        System.loadLibrary("CallOverhead");
        MemorySegment address = SymbolLookup.loaderLookup().findOrThrow("identity");
        FunctionDescriptor descriptor = FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
        IDENTITY_MH = Linker.nativeLinker().downcallHandle(address, descriptor);
    }
}
