import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

public final class CCB {
    static final StructLayout LAYOUT$S = MemoryLayout.structLayout(
            ValueLayout.ADDRESS.withTargetLayout(ValueLayout.JAVA_BYTE).withName("s1"),
            ValueLayout.ADDRESS.withTargetLayout(ValueLayout.JAVA_BYTE).withName("s2")
    );

    static final FunctionDescriptor DESCRIPTOR$ccb = FunctionDescriptor.ofVoid(ValueLayout.ADDRESS.withName("fn"));

    static final FunctionDescriptor DESCRIPTOR$callback = FunctionDescriptor.ofVoid(
            LAYOUT$S.withName("s"),
            ValueLayout.ADDRESS.withTargetLayout(ValueLayout.JAVA_BYTE).withName("data")
    );

    static final MethodHandle hCCB;
    static {
        System.loadLibrary("ccb");

        Linker linker = Linker.nativeLinker();
        SymbolLookup stdlibLookup = linker.defaultLookup();
        SymbolLookup loaderLookup = SymbolLookup.loaderLookup();

        MemorySegment pfnCCB = loaderLookup.find("ccb")
                .or(() -> stdlibLookup.find("ccb"))
                .orElse(MemorySegment.NULL);
        if (pfnCCB.equals(MemorySegment.NULL)) {
            throw new RuntimeException("Failed to find ccb symbol");
        }
        hCCB = linker.downcallHandle(pfnCCB, DESCRIPTOR$ccb);
    }

    static final class Ref<T> {
        T value;
    }

    @FunctionalInterface
    interface MemorySegmentConsumer {
        void accept(MemorySegment segment);
    }

    static void callback(
            MemorySegmentConsumer consumer,
            MemorySegment s,
            MemorySegment data
    ) {
        for (int i = 0; i < 2; i++) {
            MemorySegment segment = s.getAtIndex(ValueLayout.ADDRESS, i).reinterpret(Long.MAX_VALUE);
            System.err.println("(J) callback: s->s" + (i + 1) + " = " + segment.getString(0));
        }
        data = data.reinterpret(Long.MAX_VALUE);
        System.err.println("(J) callback: data = " + data.getString(0));
        System.err.println("(J) callback: address of data = " + Long.toUnsignedString(data.address(), 16));

        consumer.accept(data);
    }

    public static void main(String[] args) {
        Ref<MemorySegment> ref = new Ref<>();
        MemorySegmentConsumer consumer = segment -> ref.value = segment;

        try (Arena arena = Arena.ofConfined()) {
            Linker linker = Linker.nativeLinker();
            MethodHandle MH$callback = MethodHandles.lookup().findStatic(
                    CCB.class,
                    "callback",
                    DESCRIPTOR$callback.toMethodType().insertParameterTypes(0, MemorySegmentConsumer.class)
            );
            MemorySegment pfnCallback = linker.upcallStub(
                    MH$callback.bindTo(consumer),
                    DESCRIPTOR$callback,
                    arena
            );

            hCCB.invokeExact(pfnCallback);

            System.err.println("(J) main: data = " + ref.value.getString(0)); // <-- error occurs here
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }
    }
}
