package jmh;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * JMH benchmarks comparing scalar and Vector API implementations of the
 * {@link Arrays#hashCode(byte[])} polynomial hash {@code h = 31 * h + b} (seed {@code h = 1}).
 *
 * <p>The vector variants use the block form of the recurrence: for a block of {@code k} bytes
 * {@code b[0..k-1]}, {@code h' = h * 31^k + sum_j b[j] * 31^(k-1-j)}, with all arithmetic
 * modulo 2^32 (plain {@code int} overflow). The coefficient vectors below hold consecutive
 * descending powers of 31, so a lane-wise multiply followed by a lane sum gives a block's
 * contribution. Bytes are widened to ints via {@code castShape}, the vector counterpart of
 * Java's implicit {@code byte -> int} promotion used by the scalar loops.
 *
 * <p>Every variant is checked against {@link Arrays#hashCode(byte[])} during setup, so a
 * variant that is fast but wrong fails the run instead of producing a misleading score.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 1, jvmArgsPrepend = {"--add-modules=jdk.incubator.vector", "-Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=2"})
public class TestHashcode {
    static final VectorSpecies<Integer> INT_256_SPECIES = IntVector.SPECIES_256;

    static final VectorSpecies<Byte> BYTE_64_SPECIES = ByteVector.SPECIES_64;
    static final VectorSpecies<Byte> BYTE_128_SPECIES = ByteVector.SPECIES_128;
    static final VectorSpecies<Byte> BYTE_256_SPECIES = ByteVector.SPECIES_256;

    // Scalar multipliers 31^8, 31^16, 31^24 and 31^32 (mod 2^32).
    static final int COEFF_31_TO_8;
    static final int COEFF_31_TO_16;
    static final int COEFF_31_TO_24;
    static final int COEFF_31_TO_32;

    // The same multipliers broadcast to every lane of an 8 x int vector.
    static final IntVector H_COEFF_31_TO_8;
    static final IntVector H_COEFF_31_TO_16;
    static final IntVector H_COEFF_31_TO_24;
    static final IntVector H_COEFF_31_TO_32;

    // Per-lane descending powers of 31:
    //   H_COEFF_8  = [31^7  .. 31^0]
    //   H_COEFF_16 = [31^15 .. 31^8]
    //   H_COEFF_24 = [31^23 .. 31^16]
    //   H_COEFF_32 = [31^31 .. 31^24]
    static final IntVector H_COEFF_8;
    static final IntVector H_COEFF_16;
    static final IntVector H_COEFF_24;
    static final IntVector H_COEFF_32;

    /**
     * Seed lanes for the vector accumulators: lane 0 carries the hash seed 1, and the others
     * start at 0. The lane sum therefore starts at 1, matching {@link Arrays#hashCode(byte[])}.
     * The array is shared rather than allocated per invocation, so no allocation happens
     * inside the measured code.
     */
    private static final int[] H_SEED_LANES = {1, 0, 0, 0, 0, 0, 0, 0};

    static {
        // powers[i] == 31^(powers.length - 1 - i) (mod 2^32): descending powers of 31 that end
        // in 31^0 == 1 at the last index. With 8 int lanes, powers has 32 entries.
        int[] powers = new int[INT_256_SPECIES.length() * 4];
        powers[powers.length - 1] = 1;
        for (int i = powers.length - 2; i >= 0; i--) {
            powers[i] = powers[i + 1] * 31;
        }

        // powers[24] == 31^7, so one more factor of 31 gives 31^8; likewise for the others.
        COEFF_31_TO_8 = powers[24] * 31;
        COEFF_31_TO_16 = powers[16] * 31;
        COEFF_31_TO_24 = powers[8] * 31;
        COEFF_31_TO_32 = powers[0] * 31;

        H_COEFF_31_TO_8 = IntVector.broadcast(INT_256_SPECIES, COEFF_31_TO_8);
        H_COEFF_31_TO_16 = IntVector.broadcast(INT_256_SPECIES, COEFF_31_TO_16);
        H_COEFF_31_TO_24 = IntVector.broadcast(INT_256_SPECIES, COEFF_31_TO_24);
        H_COEFF_31_TO_32 = IntVector.broadcast(INT_256_SPECIES, COEFF_31_TO_32);

        H_COEFF_8 = IntVector.fromArray(INT_256_SPECIES, powers, 24);
        H_COEFF_16 = IntVector.fromArray(INT_256_SPECIES, powers, 16);
        H_COEFF_24 = IntVector.fromArray(INT_256_SPECIES, powers, 8);
        H_COEFF_32 = IntVector.fromArray(INT_256_SPECIES, powers, 0);
    }

    @Param("1024")
    int size;

    byte[] a;

    /**
     * Fills the input with random bytes, then checks that every variant agrees with the
     * reference implementation.
     *
     * @throws IllegalStateException if any variant returns a hash different from
     *     {@link Arrays#hashCode(byte[])}
     */
    @Setup
    public void init() {
        a = new byte[size];
        ThreadLocalRandom.current().nextBytes(a);
        verifyVariants();
    }

    /** Compares every benchmark variant against {@link Arrays#hashCode(byte[])} on {@link #a}. */
    private void verifyVariants() {
        int expected = Arrays.hashCode(a);
        check("vector64ReduceInLoop", expected, vector64ReduceInLoop());
        check("vector64", expected, vector64());
        check("vector64Unrolledx2", expected, vector64Unrolledx2());
        check("vector64Unrolledx4", expected, vector64Unrolledx4());
        check("vector128Unrolledx2", expected, vector128Unrolledx2());
        check("vector256Unrolledx4", expected, vector256Unrolledx4());
        check("scalarUnrolled", expected, scalarUnrolled());
    }

    /** Throws if {@code actual} differs from {@code expected}, naming the failing variant. */
    private static void check(String variant, int expected, int actual) {
        if (actual != expected) {
            throw new IllegalStateException(
                    variant + " returned " + actual + " but Arrays.hashCode returned " + expected);
        }
    }

    /**
     * Loads 8 bytes per iteration, widens them to ints and folds them into a scalar accumulator
     * with a lane reduction on every iteration.
     */
    @Benchmark
    public int vector64ReduceInLoop() {
        int h = 1;
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        int bound = BYTE_64_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_64_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_64_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h = h * COEFF_31_TO_8 + x.mul(c_H_COEFF_8).reduceLanes(VectorOperators.ADD);
        }

        // Scalar tail for the bytes that do not fill a whole vector.
        for (; i < a.length; i++) {
            h = 31 * h + a[i];
        }
        return h;
    }

    /**
     * Loads 8 bytes per iteration and keeps a vector accumulator. The lane reduction happens
     * only once, after the loop.
     */
    @Benchmark
    public int vector64() {
        IntVector h = IntVector.fromArray(INT_256_SPECIES, H_SEED_LANES, 0);
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        IntVector c_H_COEFF_31_TO_8 = H_COEFF_31_TO_8.add(0);
        int bound = BYTE_64_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_64_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_64_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h = h.mul(c_H_COEFF_31_TO_8).add(x.mul(c_H_COEFF_8));
        }

        int sh = h.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {
            sh = 31 * sh + a[i];
        }
        return sh;
    }

    /**
     * Consumes 16 bytes per iteration as two 8-byte loads feeding two independent accumulators.
     * The first half is weighted by [31^15 .. 31^8] and the second by [31^7 .. 31^0].
     */
    @Benchmark
    public int vector64Unrolledx2() {
        IntVector h1 = IntVector.fromArray(INT_256_SPECIES, H_SEED_LANES, 0);
        IntVector h2 = IntVector.zero(INT_256_SPECIES);
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_16 = H_COEFF_16.add(0);
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        IntVector c_H_COEFF_31_TO_16 = H_COEFF_31_TO_16.add(0);
        int bound = BYTE_128_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_128_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_64_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h1 = h1.mul(c_H_COEFF_31_TO_16).add(x.mul(c_H_COEFF_16));

            b = ByteVector.fromArray(BYTE_64_SPECIES, a, i + BYTE_64_SPECIES.length());
            x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h2 = h2.mul(c_H_COEFF_31_TO_16).add(x.mul(c_H_COEFF_8));
        }

        int sh = h1.reduceLanes(VectorOperators.ADD) + h2.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {
            sh = 31 * sh + a[i];
        }
        return sh;
    }

    /**
     * Consumes 32 bytes per iteration as four 8-byte loads feeding four independent
     * accumulators. Each accumulator is scaled by 31^32 per iteration.
     */
    @Benchmark
    public int vector64Unrolledx4() {
        IntVector h1 = IntVector.fromArray(INT_256_SPECIES, H_SEED_LANES, 0);
        IntVector h2 = IntVector.zero(INT_256_SPECIES);
        IntVector h3 = IntVector.zero(INT_256_SPECIES);
        IntVector h4 = IntVector.zero(INT_256_SPECIES);
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        IntVector c_H_COEFF_16 = H_COEFF_16.add(0);
        IntVector c_H_COEFF_24 = H_COEFF_24.add(0);
        IntVector c_H_COEFF_32 = H_COEFF_32.add(0);
        IntVector c_H_COEFF_31_TO_32 = H_COEFF_31_TO_32.add(0);
        int bound = BYTE_256_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_256_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_64_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h1 = h1.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_32));

            b = ByteVector.fromArray(BYTE_64_SPECIES, a, i + BYTE_64_SPECIES.length());
            x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h2 = h2.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_24));

            b = ByteVector.fromArray(BYTE_64_SPECIES, a, i + BYTE_64_SPECIES.length() * 2);
            x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h3 = h3.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_16));

            b = ByteVector.fromArray(BYTE_64_SPECIES, a, i + BYTE_64_SPECIES.length() * 3);
            x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h4 = h4.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_8));
        }

        int sh = h1.reduceLanes(VectorOperators.ADD) +
                h2.reduceLanes(VectorOperators.ADD) +
                h3.reduceLanes(VectorOperators.ADD) +
                h4.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {
            sh = 31 * sh + a[i];
        }
        return sh;
    }

    /**
     * Loads 16 bytes at once and widens the low half (castShape part 0) and the high half
     * (part 1) into two independent accumulators.
     */
    @Benchmark
    public int vector128Unrolledx2() {
        IntVector h1 = IntVector.fromArray(INT_256_SPECIES, H_SEED_LANES, 0);
        IntVector h2 = IntVector.zero(INT_256_SPECIES);
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_16 = H_COEFF_16.add(0);
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        IntVector c_H_COEFF_31_TO_16 = H_COEFF_31_TO_16.add(0);
        int bound = BYTE_128_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_128_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_128_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h1 = h1.mul(c_H_COEFF_31_TO_16).add(x.mul(c_H_COEFF_16));

            x = (IntVector) b.castShape(INT_256_SPECIES, 1);
            h2 = h2.mul(c_H_COEFF_31_TO_16).add(x.mul(c_H_COEFF_8));
        }

        int sh = h1.reduceLanes(VectorOperators.ADD) + h2.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {
            sh = 31 * sh + a[i];
        }
        return sh;
    }

    /**
     * Loads 32 bytes at once and widens each quarter (castShape parts 0..3) into one of four
     * independent accumulators.
     */
    @Benchmark
    public int vector256Unrolledx4() {
        IntVector h1 = IntVector.fromArray(INT_256_SPECIES, H_SEED_LANES, 0);
        IntVector h2 = IntVector.zero(INT_256_SPECIES);
        IntVector h3 = IntVector.zero(INT_256_SPECIES);
        IntVector h4 = IntVector.zero(INT_256_SPECIES);
        int i = 0;
        // Force into registers
        IntVector c_H_COEFF_8 = H_COEFF_8.add(0);
        IntVector c_H_COEFF_16 = H_COEFF_16.add(0);
        IntVector c_H_COEFF_24 = H_COEFF_24.add(0);
        IntVector c_H_COEFF_32 = H_COEFF_32.add(0);
        IntVector c_H_COEFF_31_TO_32 = H_COEFF_31_TO_32.add(0);
        int bound = BYTE_256_SPECIES.loopBound(a.length);
        for (; i < bound; i += BYTE_256_SPECIES.length()) {
            ByteVector b = ByteVector.fromArray(BYTE_256_SPECIES, a, i);
            IntVector x = (IntVector) b.castShape(INT_256_SPECIES, 0);
            h1 = h1.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_32));

            x = (IntVector) b.castShape(INT_256_SPECIES, 1);
            h2 = h2.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_24));

            x = (IntVector) b.castShape(INT_256_SPECIES, 2);
            h3 = h3.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_16));

            x = (IntVector) b.castShape(INT_256_SPECIES, 3);
            h4 = h4.mul(c_H_COEFF_31_TO_32).add(x.mul(c_H_COEFF_8));
        }

        int sh = h1.reduceLanes(VectorOperators.ADD) +
                h2.reduceLanes(VectorOperators.ADD) +
                h3.reduceLanes(VectorOperators.ADD) +
                h4.reduceLanes(VectorOperators.ADD);
        for (; i < a.length; i++) {
            sh = 31 * sh + a[i];
        }
        return sh;
    }

    /** Reference implementation: the JDK's scalar {@link Arrays#hashCode(byte[])}. */
    @Benchmark
    public int scalar() {
        return Arrays.hashCode(a);
    }

    /**
     * Scalar loop manually unrolled by 8, with each byte multiplied by its power of 31.
     * Returns 0 for a {@code null} array, mirroring {@link Arrays#hashCode(byte[])}.
     */
    @Benchmark
    public int scalarUnrolled() {
        if (a == null) {
            return 0;
        }

        int h = 1;
        int i = 0;
        int bound = a.length & ~(8 - 1);
        for (; i < bound; i += 8) {
            h = h * 31 * 31 * 31 * 31 * 31 * 31 * 31 * 31 +
                    a[i + 0] * 31 * 31 * 31 * 31 * 31 * 31 * 31 +
                    a[i + 1] * 31 * 31 * 31 * 31 * 31 * 31 +
                    a[i + 2] * 31 * 31 * 31 * 31 * 31 +
                    a[i + 3] * 31 * 31 * 31 * 31 +
                    a[i + 4] * 31 * 31 * 31 +
                    a[i + 5] * 31 * 31 +
                    a[i + 6] * 31 +
                    a[i + 7];
        }

        for (; i < a.length; i++) {
            h = 31 * h + a[i];
        }
        return h;
    }
}
