import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies;
import jdk.incubator.vector.VectorShuffle;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMath;

import jdk.incubator.vector.IntVector;

import java.util.function.IntFunction;

public class Test1 {

    /** Vector species under test: 256-bit vectors of {@code int} (8 lanes). */
    static final VectorSpecies<Integer> SPECIES =
                IntVector.SPECIES_256;

    /** Number of times each vectorized loop in {@link #test} is repeated per call. */
    static final int INVOC_COUNT = 100;

    /**
     * Number of SPECIES-sized chunks in every test array. Integer division, so this is 97.
     */
    static final int BUFFER_REPS = 25000 / 256;

    /** Scalar reference reduction of the chunk of {@code a} starting at {@code idx}. */
    @FunctionalInterface
    interface FReductionOp {
        int apply(int[] a, int idx);
    }

    /** Scalar reference reduction of the whole array {@code a}. */
    @FunctionalInterface
    interface FReductionAllOp {
        int apply(int[] a);
    }

    /**
     * Verifies vectorized reduction results against scalar reference implementations.
     *
     * @param r  per-chunk results; only indices that are multiples of {@code SPECIES.length()}
     *           are checked
     * @param rc result of reducing the whole array
     * @param a  the input array
     * @param f  reference reduction of one chunk
     * @param fa reference reduction of the whole array
     * @throws RuntimeException whose message starts with {@code "bad1"} if the whole-array
     *         result is wrong, or with {@code "bad2"} if a per-chunk result is wrong
     */
    static void assertReductionArraysEquals(int[] r, int rc, int[] a,
                                            FReductionOp f, FReductionAllOp fa) {
        // No try/catch here: every check throws RuntimeException directly, so a
        // catch (AssertionError) handler could never run.
        int expectedAll = fa.apply(a);
        if (rc != expectedAll) {
            throw new RuntimeException("bad1: whole-array reduction is " + rc
                    + ", expected " + expectedAll);
        }
        for (int i = 0; i < a.length; i += SPECIES.length()) {
            int expected = f.apply(a, i);
            if (r[i] != expected) {
                throw new RuntimeException("bad2: reduction at index #" + i + " is " + r[i]
                        + ", expected " + expected);
            }
        }
    }

    /** Maps an array index to the value stored at that index. */
    @FunctionalInterface
    interface ToIntF {
        int apply(int i);
    }

    /**
     * Allocates an array of length {@code s} and fills it with {@code f}.
     *
     * @param s length of the new array
     * @param f value generator, called with each index
     * @return the filled array
     */
    static int[] fill(int s, ToIntF f) {
        return fill(new int[s], f);
    }

    /**
     * Overwrites every element of {@code a} with {@code f.apply(index)}.
     *
     * @param a array to fill in place
     * @param f value generator, called with each index
     * @return {@code a}, for chaining
     */
    static int[] fill(int[] a, ToIntF f) {
        for (int i = 0; i < a.length; i++) {
            a[i] = f.apply(i);
        }
        return a;
    }

    /**
     * Allocates a zeroed result array of {@code BUFFER_REPS * vl} elements. This matches the
     * length of the input arrays built in {@link #main}.
     */
    static final IntFunction<int[]> fr = (vl) -> {
        int length = BUFFER_REPS * vl;
        return new int[length];
    };

    /**
     * Scalar sum of the {@code SPECIES.length()} elements of {@code a} starting at {@code idx}.
     * Requires {@code idx + SPECIES.length() <= a.length}; otherwise the array access throws.
     */
    static int ADDReduce(int[] a, int idx) {
        int res = 0;
        for (int i = idx; i < (idx + SPECIES.length()); i++) {
            res += a[i];
        }

        return res;
    }

    /**
     * Scalar sum of {@code a}, computed chunk by chunk with {@link #ADDReduce}. Requires
     * {@code a.length} to be a multiple of {@code SPECIES.length()}.
     */
    static int ADDReduceAll(int[] a) {
        int res = 0;
        for (int i = 0; i < a.length; i += SPECIES.length()) {
            res += ADDReduce(a, i);
        }

        return res;
    }

    /**
     * Runs the vectorized ADD reduction and checks it against the scalar reference.
     *
     * <p>Two loops run {@code INVOC_COUNT} times each:
     * <ul>
     *   <li>one stores each chunk's lane sum at the chunk's start index in {@code r};</li>
     *   <li>one accumulates all chunk sums into {@code ra}.</li>
     * </ul>
     * The repetition is presumably there to get the loops JIT-compiled (TODO confirm intent).
     *
     * @param fa builds the input array given the species length
     */
    static void test(IntFunction<int[]> fa) {
        int[] a = fa.apply(SPECIES.length());
        int[] r = fr.apply(SPECIES.length());
        int ra = 0;

        for (int ic = 0; ic < INVOC_COUNT; ic++) {
            for (int i = 0; i < a.length; i += SPECIES.length()) {
                IntVector av = IntVector.fromArray(SPECIES, a, i);
                r[i] = av.reduceLanes(VectorOperators.ADD);
            }
        }

        for (int ic = 0; ic < INVOC_COUNT; ic++) {
            ra = 0;
            for (int i = 0; i < a.length; i += SPECIES.length()) {
                IntVector av = IntVector.fromArray(SPECIES, a, i);
                ra += av.reduceLanes(VectorOperators.ADD);
            }
        }

        assertReductionArraysEquals(r, ra, a,
                Test1::ADDReduce, Test1::ADDReduceAll);
    }

    // ------------------------ COPIED CODE -------------------------

    /**
     * Alternates the two input generators, int[-i * 5] and int[i * 5], 100 times.
     *
     * <p>TODO: why are both generators needed? A disabled variant ("test v2") ran only the
     * int[i * 5] generator.
     */
    public static void main(String[] args) {
        System.out.println("test v1");
        for (int i = 0; i < 100; i++) {
            test((int s) -> fill(s * BUFFER_REPS, j -> (int)(-j * 5)));
            test((int s) -> fill(s * BUFFER_REPS, j -> (int)(j * 5)));
        }
    }
}
