import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HexFormat;

import static java.lang.String.format;

/**
 * Round-trip check for {@link Base64.Decoder#decode(ByteBuffer)} over many inputs of varying length.
 *
 * <p>The static initializer builds {@value #POSITIONS} plain byte sequences ("positions") stored back to back
 * in {@link #plainBytes}. It stores their standard, padded Base64 encodings back to back in
 * {@link #base64Bytes}. {@link #decodeAndCheck()} then decodes every encoding and compares it with the original
 * plain bytes, failing on the first discrepancy.
 *
 * <p>NOTE(review): the class name suggests this targets the AVX-512 ("AVX3") Base64 decode intrinsic; confirm.
 */
public class Base64Avx3 {
    /** Number of independent plain/Base64 samples generated and checked. */
    static final int POSITIONS = 30_000;
    /** Plain length of position 0; position {@code i} has {@code (BASE_LENGTH + i) % LENGTH_CYCLE} bytes. */
    static final int BASE_LENGTH = 256;
    /** Upper-case, space-delimited hex formatter used in mismatch reports. */
    static final HexFormat HEX_FORMAT = HexFormat.of().withUpperCase().withDelimiter(" ");

    /** Plain sample lengths cycle through {@code 0 .. LENGTH_CYCLE - 1} as the position increases. */
    private static final int LENGTH_CYCLE = 2048;

    /** Start offset of each position in {@link #plainBytes}; the last entry is the end sentinel. */
    static int[] plainOffsets = new int[POSITIONS + 1];
    /** All plain samples, concatenated; byte {@code k} holds {@code (byte) k}. */
    static byte[] plainBytes;
    /** Start offset of each position in {@link #base64Bytes}; the last entry is the end sentinel. */
    static int[] base64Offsets = new int[POSITIONS + 1];
    /** Base64 encodings of all plain samples, concatenated in position order. */
    static byte[] base64Bytes;

    static {
        initPlain();
        initBase64();
    }

    /** Lays out the plain samples back to back and fills them with a wrapping byte ramp. */
    private static void initPlain() {
        int plainLength = 0;
        for (int i = 0; i < POSITIONS; i++) {
            plainOffsets[i] = plainLength;
            plainLength = Math.addExact(plainLength, (BASE_LENGTH + i) % LENGTH_CYCLE);
        }
        // End sentinel, so every position can use offsets[i + 1] - offsets[i]; no phantom extra position.
        plainOffsets[POSITIONS] = plainLength;
        plainBytes = new byte[plainLength];
        for (int i = 0; i < plainBytes.length; i++) {
            plainBytes[i] = (byte) i;
        }
    }

    /** Encodes every plain sample once and lays the encodings out back to back. */
    private static void initBase64() {
        int base64Length = 0;
        for (int i = 0; i < POSITIONS; i++) {
            base64Offsets[i] = base64Length;
            int plainLength = plainOffsets[i + 1] - plainOffsets[i];
            base64Length = Math.addExact(base64Length, encodedLength(plainLength));
        }
        base64Offsets[POSITIONS] = base64Length;
        base64Bytes = new byte[base64Length];

        Base64.Encoder encoder = Base64.getEncoder();
        for (int i = 0; i < POSITIONS; i++) {
            ByteBuffer encodedBytes = encoder.encode(plainBytesAtPosition(i));
            int base64Offset = base64Offsets[i];
            int expectedLength = base64Offsets[i + 1] - base64Offset;
            // Independent check: the layout above used the formula, not the encoder's own output.
            if (expectedLength != encodedBytes.remaining()) {
                throw new IllegalStateException(format("Unexpected length at position %d: %s <> %s",
                        i, encodedBytes.remaining(), expectedLength));
            }
            encodedBytes.get(base64Bytes, base64Offset, expectedLength);
        }
    }

    /** Returns the length of the padded, non-MIME Base64 encoding of {@code plainLength} bytes. */
    private static int encodedLength(int plainLength) {
        return 4 * ((plainLength + 2) / 3);
    }

    public static void main(String[] args) {
        decodeAndCheck();
        System.out.println("Test complete, no invalid decodes detected");
    }

    /**
     * Decodes every stored Base64 position and compares the result with the corresponding plain bytes.
     *
     * @throws IllegalStateException at the first position whose decode is rejected or yields bytes different
     *     from the expected plain bytes; the message names the position, the Base64 input and the hex dumps
     */
    static void decodeAndCheck() {
        Base64.Decoder decoder = Base64.getDecoder();
        for (int i = 0; i < POSITIONS; i++) {
            ByteBuffer decodedBytes;
            try {
                decodedBytes = decoder.decode(base64BytesAtPosition(i));
            } catch (IllegalArgumentException e) {
                // The input came from Base64.getEncoder(), so a rejection is itself a decoder failure.
                throw new IllegalStateException(
                        format("Decoder rejected valid input at position %d: %s", i, base64StringAtPosition(i)), e);
            }

            if (!decodedBytes.equals(plainBytesAtPosition(i))) {
                String decodedHexString = HEX_FORMAT.formatHex(
                        decodedBytes.array(),
                        decodedBytes.arrayOffset() + decodedBytes.position(),
                        decodedBytes.arrayOffset() + decodedBytes.limit());
                throw new IllegalStateException(format("Mismatch at position %d for %s\n\nExpected:\n%s\n\nActual:\n%s",
                        i, base64StringAtPosition(i), plainHexStringAtPosition(i), decodedHexString));
            }
        }
    }

    /** Returns a buffer over {@link #plainBytes} whose remaining bytes are the plain sample at {@code position}. */
    static ByteBuffer plainBytesAtPosition(int position) {
        int offset = plainOffsets[position];
        int length = plainOffsets[position + 1] - offset;
        return ByteBuffer.wrap(plainBytes, offset, length);
    }

    /** Returns the plain sample at {@code position} as upper-case, space-delimited hex. */
    static String plainHexStringAtPosition(int position) {
        return HEX_FORMAT.formatHex(plainBytes, plainOffsets[position], plainOffsets[position + 1]);
    }

    /** Returns the Base64 encoding at {@code position} as a string (Base64 output is ASCII). */
    static String base64StringAtPosition(int position) {
        int offset = base64Offsets[position];
        int length = base64Offsets[position + 1] - offset;
        return new String(base64Bytes, offset, length, StandardCharsets.UTF_8);
    }

    /** Returns a buffer over {@link #base64Bytes} whose remaining bytes are the encoding at {@code position}. */
    static ByteBuffer base64BytesAtPosition(int position) {
        int offset = base64Offsets[position];
        int length = base64Offsets[position + 1] - offset;
        return ByteBuffer.wrap(base64Bytes, offset, length);
    }
}
