package org.openjdk.bench.valhalla;

import jdk.internal.value.ValueClass;
import jdk.internal.vm.annotation.LooselyConsistentValue;
import org.openjdk.jmh.annotations.*;

import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.function.IntFunction;

/**
 * Microbenchmark that maps every element of a null-restricted, non-atomic {@link Point} array
 * through {@link Point#times(int)} into a freshly allocated destination array of the same kind.
 *
 * <p>The two benchmark methods run the same loop. They differ only in how the destination array
 * is obtained:
 * <ul>
 *   <li>{@link #flatInline()} uses a factory annotated {@code CompilerControl.Mode.INLINE};</li>
 *   <li>{@link #flatDontInline()} uses an otherwise identical factory annotated
 *       {@code CompilerControl.Mode.DONT_INLINE}.</li>
 * </ul>
 * The pair therefore contrasts an allocation site the JIT can see with one hidden behind a call.
 */
@State(Scope.Thread)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Fork(value = 1, jvmArgs = {
        "--enable-preview",
        "-XX:+UseArrayFlattening",
        "--add-exports=java.base/jdk.internal.vm.annotation=ALL-UNNAMED",
        "--add-exports=java.base/jdk.internal.value=ALL-UNNAMED",
        "--add-exports=java.base/jdk.internal.classfile.components=ALL-UNNAMED"})
public class FlatArrayMap {

    /** Element count of the input array and of every array produced by a benchmark call. */
    private static final int SIZE = 1_000_000;

    /** Factor every coordinate is multiplied by inside the measured loops. */
    private static final int SCALAR = 23;

    /** Unseeded generator supplying the coordinates of the input points. */
    private static final Random RANDOM = new Random();

    /** Benchmark input: built with the inlined factory and filled with random points at construction. */
    private final Point[] flat = fill(createFlatInline(SIZE), FlatArrayMap::randomPoint);

    /**
     * Scales every input point into an array obtained from the non-inlined factory.
     *
     * @return the scaled points; returning them keeps the work observable to JMH
     */
    @Benchmark
    public Point[] flatDontInline() {
        Point[] scaled = createFlatDontInline(SIZE);
        for (int idx = 0; idx < SIZE; idx++) {
            scaled[idx] = flat[idx].times(SCALAR);
        }
        return scaled;
    }

    /**
     * Scales every input point into an array obtained from the inlined factory.
     *
     * @return the scaled points; returning them keeps the work observable to JMH
     */
    @Benchmark
    public Point[] flatInline() {
        Point[] scaled = createFlatInline(SIZE);
        for (int idx = 0; idx < SIZE; idx++) {
            scaled[idx] = flat[idx].times(SCALAR);
        }
        return scaled;
    }

    /**
     * Allocates a null-restricted, non-atomic {@code Point[]} of the given length, seeded with
     * {@link Point#DEFAULT}. The JIT is instructed to inline this factory into its callers.
     */
    @CompilerControl(CompilerControl.Mode.INLINE)
    static Point[] createFlatInline(int size) {
        return (Point[]) ValueClass.newNullRestrictedNonAtomicArray(Point.class, size, Point.DEFAULT);
    }

    /**
     * Same allocation as {@link #createFlatInline(int)}, but the JIT is instructed never to
     * inline this method, so callers see only an opaque call.
     */
    @CompilerControl(CompilerControl.Mode.DONT_INLINE)
    static Point[] createFlatDontInline(int size) {
        return (Point[]) ValueClass.newNullRestrictedNonAtomicArray(Point.class, size, Point.DEFAULT);
    }

    /**
     * Produces a point with two random coordinates; the index argument is unused and exists
     * only to satisfy {@link IntFunction}.
     */
    private static Point randomPoint(int ignoredIndex) {
        int px = RANDOM.nextInt();
        int py = RANDOM.nextInt();
        return new Point(px, py);
    }

    /**
     * Overwrites every slot of {@code points} with {@code factory.apply(slot)}, in ascending
     * index order, and hands back the same array instance.
     */
    private static Point[] fill(Point[] points, IntFunction<Point> factory) {
        int slot = 0;
        while (slot < points.length) {
            points[slot] = factory.apply(slot);
            slot++;
        }
        return points;
    }

    /**
     * Two-int value record used as the array element type. It is annotated
     * {@code @LooselyConsistentValue}, presumably so it qualifies for the non-atomic arrays
     * created above — confirm against the Valhalla preview documentation.
     */
    @LooselyConsistentValue
    public value record Point(int x, int y) {

        /** All-zero point, passed as the initial value when the benchmark arrays are allocated. */
        static final Point DEFAULT = new Point(0, 0);

        /**
         * Returns a new point whose coordinates are this point's coordinates times
         * {@code scalar}, using plain {@code int} arithmetic (overflow wraps silently).
         */
        public Point times(int scalar) {
            return new Point(this.x * scalar, this.y * scalar);
        }
    }
}

