package jmh;

import jdk.incubator.foreign.MemoryAccess;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemoryLayouts;
import jdk.incubator.foreign.MemorySegment;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Threads;
import org.openjdk.jmh.annotations.Warmup;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.concurrent.TimeUnit;

/**
 * JMH benchmarks that increment every element of an int sequence in place through
 * different access paths: a plain {@code int[]}, {@link IntBuffer} instances created in
 * three different ways, and {@link MemorySegment}s (native and array-backed).
 *
 * <p>The loop bodies are deliberately kept as simple index loops; they are the code under
 * measurement, so they should not be restructured without re-validating the results.
 */
@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
// 1,000,000 microseconds = one second per warmup/measurement iteration.
@Warmup(iterations = 5, time = 1000000, timeUnit = TimeUnit.MICROSECONDS)
@Measurement(iterations = 3, time = 1000000, timeUnit = TimeUnit.MICROSECONDS)
@Threads(value = 1)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.foreign", "--enable-preview"})
public class TestBuffer {

    /**
     * Data shared by all benchmark methods. Every container holds {@link #N} ints and,
     * except for {@link #nativeSegment}, is filled with {@code 0..N-1} by the constructor.
     */
    @State(Scope.Benchmark)
    public static class BenchmarkState {
        /** Number of int elements in every container. */
        static final int N = 50000;

        /** Plain array; also the backing storage of {@link #heapSegment}. */
        final int[] array = new int[N];
        /** Buffer obtained from {@link IntBuffer#allocate(int)}. */
        final IntBuffer buffer = IntBuffer.allocate(N);
        /** Array offset reported by {@link #buffer}; applied to {@link #array} in {@code arrayOffset}. */
        final int offset = buffer.arrayOffset();
        /** Int view over a heap {@link ByteBuffer} in native byte order. */
        final IntBuffer heap_buffer_byte_to_int = ByteBuffer.allocate(N * Integer.BYTES).order(ByteOrder.nativeOrder()).asIntBuffer();
        /** Int view over a direct {@link ByteBuffer} in native byte order. */
        final IntBuffer direct_buffer_byte_to_int = ByteBuffer.allocateDirect(N * Integer.BYTES).order(ByteOrder.nativeOrder()).asIntBuffer();
        /** Native segment sized for {@link #N} ints; the constructor does not write to it. */
        final MemorySegment nativeSegment = MemorySegment.allocateNative(MemoryLayout.ofSequence(N, MemoryLayouts.JAVA_INT));
        /** Segment view over {@link #array}; writes through it are visible in the array. */
        final MemorySegment heapSegment = MemorySegment.ofArray(array);

        /**
         * Fills the array and the three buffers with {@code 0..N-1} and prints each
         * buffer's {@code hasArray()} and byte order to standard output.
         */
        public BenchmarkState() {
            for (int k = 0; k < array.length; k++) {
                array[k] = k;
                buffer.put(k, k);
                heap_buffer_byte_to_int.put(k, k);
                direct_buffer_byte_to_int.put(k, k);
            }
            System.out.println("buffer.hasArray(): " + buffer.hasArray());
            System.out.println("heap_buffer_byte_to_int.hasArray(): " + heap_buffer_byte_to_int.hasArray());
            System.out.println("direct_buffer_byte_to_int.hasArray(): " + direct_buffer_byte_to_int.hasArray());
            System.out.println("buffer.order: " + buffer.order());
            System.out.println("heap_buffer_byte_to_int.order: " + heap_buffer_byte_to_int.order());
            System.out.println("direct_buffer_byte_to_int.order: " + direct_buffer_byte_to_int.order());
        }
    }

    /**
     * Releases both memory segments held by the shared state.
     *
     * <p>This fixture belongs to the thread-scoped benchmark class but closes resources of
     * the benchmark-scoped state, so each close is guarded by an aliveness check: calling
     * {@code close()} on an already-closed segment would throw. The {@code finally} block
     * makes sure the heap segment is still closed when closing the native one fails.
     *
     * @param s the shared benchmark state whose segments are closed
     */
    @TearDown
    public void tearDown(BenchmarkState s) {
        try {
            closeIfAlive(s.nativeSegment);
        } finally {
            closeIfAlive(s.heapSegment);
        }
    }

    /** Closes {@code segment} unless it has already been closed. */
    private static void closeIfAlive(MemorySegment segment) {
        if (segment.isAlive()) {
            segment.close();
        }
    }

    /** Increments every element of the plain array by direct indexing. */
    @Benchmark
    public void array(BenchmarkState s) {
        for (int k = 0; k < s.array.length; k++) {
            s.array[k] += 1;
        }
    }

    /**
     * Increments elements of the plain array, adding the state's {@code offset} to every
     * index and shortening the loop by the same amount.
     */
    @Benchmark
    public void arrayOffset(BenchmarkState s) {
        int l = s.array.length - s.offset;
        for (int k = 0; k < l; k++) {
            s.array[k + s.offset] += 1;
        }
    }

    /** Increments every element of the {@code IntBuffer.allocate} buffer via absolute get/put. */
    @Benchmark
    public void buffer(BenchmarkState s) {
        for (int k = 0; k < s.buffer.limit(); k++) {
            s.buffer.put(k, s.buffer.get(k) + 1);
        }
    }

    /**
     * Increments every int of the native segment. The loop index is a byte offset,
     * advanced by 4 (the size of an int) per element.
     */
    @Benchmark
    public void segmentNative(BenchmarkState s) {
        for (int k = 0; k < (int) s.nativeSegment.byteSize(); k += 4) {
            int v = MemoryAccess.getIntAtOffset(s.nativeSegment, k);
            MemoryAccess.setIntAtOffset(s.nativeSegment, k, v + 1);
        }
    }

    /**
     * Increments every int of the array-backed segment. The loop index is a byte offset,
     * advanced by 4 (the size of an int) per element. The segment wraps the state's plain
     * array, so this mutates the same data as {@link #array(BenchmarkState)}.
     */
    @Benchmark
    public void segmentHeap(BenchmarkState s) {
        for (int k = 0; k < (int) s.heapSegment.byteSize(); k += 4) {
            int v = MemoryAccess.getIntAtOffset(s.heapSegment, k);
            MemoryAccess.setIntAtOffset(s.heapSegment, k, v + 1);
        }
    }

    /** Increments every element of the int view over the direct byte buffer. */
    @Benchmark
    public void bufferDirect_byte_to_int(BenchmarkState s) {
        for (int k = 0; k < s.direct_buffer_byte_to_int.limit(); k++) {
            s.direct_buffer_byte_to_int.put(k, s.direct_buffer_byte_to_int.get(k) + 1);
        }
    }

    /** Increments every element of the int view over the heap byte buffer. */
    @Benchmark
    public void bufferHeap_byte_to_int(BenchmarkState s) {
        for (int k = 0; k < s.heap_buffer_byte_to_int.limit(); k++) {
            s.heap_buffer_byte_to_int.put(k, s.heap_buffer_byte_to_int.get(k) + 1);
        }
    }
}
