import java.lang.foreign.*;
import java.lang.invoke.*;
import java.util.*;

import static java.lang.foreign.ValueLayout.*;

/**
 * Round-trip check for fixed-width integers exchanged between Java and a native library through
 * the Foreign Function &amp; Memory API.
 *
 * <p>For every pairing of a Java value layout with a C integer type, each test value is passed to
 * the native function {@code jnhw_<type>_set}. It is then read back twice: directly from the
 * native variable {@code jnhw_<type>_mem}, and through the native function
 * {@code jnhw_<type>_get}. Both reads are expected to equal the value that was written.
 */
public class Main {

    static {
        // SymbolLookup.loaderLookup() resolves symbols only from libraries loaded by code in
        // classes of this class loader, so the library is loaded here.
        System.loadLibrary("mylib");
    }

    static final Linker LINKER = Linker.nativeLinker();
    static final SymbolLookup SB = SymbolLookup.loaderLookup();

    /** Runs the round-trip check for every (Java layout, C type) pairing. */
    public static void main(String[] args) throws Throwable {
        // Each Java layout is paired with both the signed and the unsigned C type of equal width.
        record TestCase(ValueLayout javaType, String nativeType) {}
        List<TestCase> cases = List.of(
            new TestCase(JAVA_BYTE, "int8_t"),
            new TestCase(JAVA_BYTE, "uint8_t"),
            new TestCase(JAVA_SHORT, "int16_t"),
            new TestCase(JAVA_SHORT, "uint16_t"),
            new TestCase(JAVA_INT, "int32_t"),
            new TestCase(JAVA_INT, "uint32_t"),
            new TestCase(JAVA_LONG, "int64_t"),
            new TestCase(JAVA_LONG, "uint64_t")
        );
        // NOTE: every value fits in 8 bits. Only the 8-bit cases see a set sign bit
        // (0x8f, 0xf0 and 0xff narrow to negative bytes). For the wider types they stay positive.
        long[] testValues = { 0x01, 0x8f, 0x00, 0xf0, 0xff };

        for (TestCase testCase : cases) {
            test(testCase.javaType(), testCase.nativeType(), testValues);
        }
    }

    /**
     * Round-trips every value in {@code testValues} through the native symbols of one C type.
     *
     * @param javaType   the Java layout used for the downcall signatures and the memory read
     * @param nativeType the C type name embedded in the native symbol names, e.g. {@code "uint8_t"}
     * @param testValues the values to exercise; each is narrowed to {@code javaType}'s carrier
     * @throws AssertionError if the memory read or the {@code _get} result differs from the value
     *                        passed to {@code _set}
     * @throws Throwable      propagated from symbol lookup (missing symbol) or from the downcalls
     */
    static void test(ValueLayout javaType, String nativeType, long[] testValues) throws Throwable {
        System.out.println("Testing: " + javaType + " = " + nativeType);

        // The symbol lookups, downcall stubs and var handle depend only on the type, not on the
        // value under test, so they are built once here rather than once per value.
        String prefix = "jnhw_" + nativeType;
        MethodHandle mhSet = LINKER.downcallHandle(
                SB.findOrThrow(prefix + "_set"), FunctionDescriptor.ofVoid(javaType));
        MethodHandle mhGet = LINKER.downcallHandle(
                SB.findOrThrow(prefix + "_get"), FunctionDescriptor.of(javaType));
        // A looked-up symbol is a zero-length segment. Resize it to the variable's size so it
        // can be read.
        MemorySegment msMem = SB.findOrThrow(prefix + "_mem").reinterpret(javaType.byteSize());
        VarHandle memHandle = javaType.varHandle();

        for (long longValue : testValues) {
            Object value = convert(longValue, javaType.carrier());
            System.out.print("  " + value + ": ");

            mhSet.invoke(value);

            // Coordinates are (segment, base offset); the variable starts at offset 0.
            assertEquals(nativeType + " memory after _set", value, memHandle.get(msMem, 0L));
            assertEquals(nativeType + " _get", value, mhGet.invoke());
            System.out.println("ok");
        }
    }

    /**
     * Narrows {@code longValue} to the primitive {@code targetType} with Java cast semantics and
     * returns the result boxed. For example, {@code 0x8f} becomes {@code (byte) -113}.
     *
     * @param longValue  the value to narrow
     * @param targetType a primitive carrier class such as {@code byte.class}
     * @return the narrowed value, boxed in {@code targetType}'s wrapper class
     * @throws Throwable propagated from the conversion handle
     */
    static Object convert(long longValue, Class<?> targetType) throws Throwable {
        // explicitCastArguments applies a Java casting conversion between primitive types.
        MethodHandle converter = MethodHandles.explicitCastArguments(
                MethodHandles.identity(targetType),
                MethodType.methodType(targetType, long.class));
        return converter.invoke(longValue);
    }

    /**
     * Throws an {@link AssertionError} unless {@code a} and {@code b} are equal per
     * {@link Objects#equals}.
     */
    static void assertEquals(Object a, Object b) {
        if (!Objects.equals(a, b)) {
            throw new AssertionError("Mismatch: " + a + " != " + b);
        }
    }

    /**
     * Throws an {@link AssertionError} naming {@code context} unless {@code expected} and
     * {@code actual} are equal per {@link Objects#equals}.
     *
     * @param context  a description of the check, included in the failure message
     * @param expected the expected value
     * @param actual   the observed value
     */
    static void assertEquals(String context, Object expected, Object actual) {
        if (!Objects.equals(expected, actual)) {
            throw new AssertionError(
                    context + " mismatch: expected " + expected + " but was " + actual);
        }
    }
}
