import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

/**
 * Compares a Vector API implementation of {@code dst[i] = (byte)(src[i] >>> 3)} against the
 * plain scalar loop and reports the first index where they disagree.
 *
 * <p>Java's scalar {@code >>>} applied to a {@code byte} first promotes the operand to
 * {@code int} with sign extension, shifts the 32-bit value, and the {@code (byte)} cast keeps
 * only the low 8 bits. {@link VectorOperators#LSHR} on byte lanes instead shifts the 8-bit lane
 * as if it were unsigned, so for negative inputs the two disagree (for example {@code -128}
 * becomes {@code -16} in scalar code but {@code 16} with {@code LSHR}). The vector loop below
 * therefore uses {@link VectorOperators#ASHR}, which matches the scalar result.
 */
class ByteShift {
    /**
     * Length of the test arrays. 257 is not a multiple of any power-of-two lane count above 1,
     * so the scalar tail loop in {@link #urshiftVector} also runs.
     */
    public static final int SIZE = 257;

    /** Vector shape used by {@link #urshiftVector}: the platform's preferred byte species. */
    public static final VectorSpecies<Byte> spec = ByteVector.SPECIES_PREFERRED;

    /**
     * Shift distance used by both implementations. The ASHR/{@code >>>} equivalence relied on
     * in {@link #urshiftVector} requires this to stay within {@code 0..7}.
     */
    private static final int SHIFT = 3;

    /** Input bytes; {@link #main} fills it with {@code (byte) i}, i.e. 0..127, -128..-1, 0. */
    public static byte[] a = new byte[SIZE];
    /** Scalar reference output written by {@link #urshift}. */
    public static byte[] b = new byte[SIZE];
    /** Output written by {@link #urshiftVector}. */
    public static byte[] c = new byte[SIZE];

    /**
     * Scalar reference: stores {@code (byte)(src[i] >>> 3)} into {@code dst[i]} for every index
     * of {@code src}.
     *
     * @param src input bytes
     * @param dst output array; must be at least {@code src.length} long
     * @throws ArrayIndexOutOfBoundsException if {@code dst} is shorter than {@code src}
     */
    public static void urshift(byte[] src, byte[] dst) {
        for (int i = 0; i < src.length; i++) {
            dst[i] = (byte) (src[i] >>> SHIFT);
        }
    }

    /**
     * Vectorized equivalent of {@link #urshift}: writes exactly the same bytes into {@code dst}.
     *
     * <p>The vector body uses {@link VectorOperators#ASHR} rather than {@code LSHR}. For a byte
     * {@code x} and a shift {@code s} with {@code 0 <= s <= 7},
     * {@code (byte)(x >>> s) == (byte)(x >> s)}: the int promotion sign-extends {@code x}, and
     * {@code >>>} and {@code >>} differ only in the top {@code s} bits of the 32-bit result,
     * which the narrowing cast discards. {@code LSHR} on byte lanes zero-extends the lane
     * instead, so it would disagree with the scalar tail loop (and with {@link #urshift}) for
     * every negative input.
     *
     * @param src input bytes
     * @param dst output array; must be at least {@code src.length} long
     * @throws IndexOutOfBoundsException if {@code dst} is shorter than {@code src}
     */
    public static void urshiftVector(byte[] src, byte[] dst) {
        // Largest multiple of the lane count not exceeding src.length; loop-invariant.
        int upperBound = spec.loopBound(src.length);
        int i = 0;
        for (; i < upperBound; i += spec.length()) {
            ByteVector.fromArray(spec, src, i)
                    .lanewise(VectorOperators.ASHR, SHIFT)
                    .intoArray(dst, i);
        }

        // Scalar tail for the remaining elements that do not fill a whole vector.
        for (; i < src.length; i++) {
            dst[i] = (byte) (src[i] >>> SHIFT);
        }
    }

    /**
     * Fills {@link #a}, computes the scalar reference into {@link #b}, runs
     * {@link #urshiftVector} repeatedly into {@link #c}, then compares {@code b} with {@code c}.
     * On the first mismatch it prints the index, the input and both results and exits with
     * status -1; otherwise it prints {@code Done}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        for (int i = 0; i < a.length; i++) {
            a[i] = (byte) i;
        }

        urshift(a, b);

        // Presumably repeated so the JIT compiles urshiftVector and the checked result comes
        // from compiled code rather than the interpreter — TODO confirm intent.
        for (int iter = 0; iter < 10000; iter++) {
            urshiftVector(a, c);
        }

        for (int i = 0; i < b.length; i++) {
            if (b[i] != c[i]) {
                System.out.println("i: " + i);
                System.out.println("a[i]: " + a[i]);
                System.out.println("scalar: " + b[i]);
                System.out.println("vector: " + c[i]);
                System.out.println("");
                System.exit(-1);
            }
        }

        System.out.println("Done");
    }
}
