package org.java.awt.image;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

import javax.imageio.*;
import javax.imageio.stream.ImageOutputStream;
import java.awt.*;
import java.awt.image.*;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Hashtable;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

/**
 * Compares two ways of decoding the same progressive JPEG: {@link ImageIO#read} versus the
 * {@link Toolkit} image pipeline, where the decoded pixels are collected by a custom
 * {@link ImageConsumer} into a {@link BufferedImage}.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 20)
@Fork(3)
@State(Scope.Thread)
public class JPEG_Progressive {

    /** Encoded progressive JPEG bytes produced by {@link #setup()}. */
    byte[] jpgImageData;

    /**
     * Generates a 2500x2500 test image and encodes it as a progressive JPEG.
     */
    @Setup
    public void setup() throws Exception {
        jpgImageData = createImageData(2_500);

        // To visually confirm that createBufferedImage() produces the same image as
        // ImageIO, uncomment the following and compare the two PNG files:
//        BufferedImage bi1 = ImageIO.read(new ByteArrayInputStream(jpgImageData));
//        ImageIO.write(bi1, "png", new java.io.File("xanth1.png"));
//
//        Image img = Toolkit.getDefaultToolkit().createImage(jpgImageData);
//        BufferedImage bi2 = createBufferedImage(img);
//        ImageIO.write(bi2, "png", new java.io.File("xanth2.png"));
    }

    /**
     * Decodes the JPEG bytes with {@link ImageIO#read}.
     */
    @Benchmark
    public void measureImageIO(Blackhole bh) throws Exception {
        BufferedImage bi = ImageIO.read(new ByteArrayInputStream(jpgImageData));
        bi.flush();
        bh.consume(bi);
    }

    /**
     * Decodes the JPEG bytes through {@link Toolkit#createImage(byte[])} and an
     * {@link ImageConsumer} that copies the delivered pixels into a {@link BufferedImage}.
     */
    @Benchmark
    public void measureImageConsumer(Blackhole bh) throws Exception {
        Image img = Toolkit.getDefaultToolkit().createImage(jpgImageData);
        BufferedImage bi = createBufferedImage(img);
        bi.flush();
        bh.consume(bi);
    }

    /**
     * Drains {@code img}'s {@link ImageProducer} into a new {@code TYPE_INT_RGB} image and
     * waits for production to finish.
     *
     * <p>Production may be delivered asynchronously (hence the future), so every failure path
     * -- an exception thrown from a consumer callback, or an error/abort completion status --
     * completes the future exceptionally. Otherwise the calling thread would block in
     * {@code get()} forever, or receive a {@code null} image.
     *
     * @param img the image whose source is consumed
     * @return the decoded image
     * @throws Exception if production fails or delivers an unsupported pixel layout
     *         (typically wrapped in a {@link java.util.concurrent.ExecutionException})
     */
    private BufferedImage createBufferedImage(Image img) throws Exception {
        CompletableFuture<BufferedImage> future = new CompletableFuture<>();
        img.getSource().startProduction(new ImageConsumer() {
            int imageWidth, imageHeight;
            BufferedImage bi;

            @Override
            public void setDimensions(int width, int height) {
                imageWidth = width;
                imageHeight = height;
            }

            @Override
            public void setProperties(Hashtable<?, ?> props) {
                // intentionally empty
            }

            @Override
            public void setColorModel(ColorModel model) {
                // intentionally empty
            }

            @Override
            public void setHints(int hintflags) {
                // intentionally empty
            }

            @Override
            public void setPixels(int x, int y, int w, int h, ColorModel model, byte[] pixels, int off, int scansize) {
                fail(new UnsupportedOperationException("byte[] pixel delivery is not supported"));
            }

            @Override
            public void setPixels(int x, int y, int w, int h, ColorModel model, int[] pixels, int off, int scansize) {
                try {
                    storePixels(x, y, w, h, model, pixels, off, scansize);
                } catch (RuntimeException e) {
                    fail(e);
                }
            }

            @Override
            public void imageComplete(int status) {
                if (status == IMAGEERROR || status == IMAGEABORTED) {
                    future.completeExceptionally(
                            new IllegalStateException("image production failed, status=" + status));
                } else if (bi == null) {
                    future.completeExceptionally(
                            new IllegalStateException("image completed without delivering any pixels"));
                } else {
                    future.complete(bi);
                }
            }

            /** Copies one delivered rectangle of packed int pixels into {@code bi}. */
            private void storePixels(int x, int y, int w, int h, ColorModel model, int[] pixels, int off, int scansize) {
                if (bi == null) {
                    bi = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
                }
                // Only full-width rows in the destination's own color model are supported, so the
                // ints can be stored as raw data elements without any color conversion.
                if (w != imageWidth || !model.equals(bi.getColorModel())) {
                    throw new UnsupportedOperationException("unsupported pixel layout: w=" + w
                            + ", imageWidth=" + imageWidth + ", model=" + model);
                }
                int[] data = pixels;
                if (off != 0 || scansize != w) {
                    // setDataElements expects exactly w*h contiguous elements starting at index 0.
                    data = new int[w * h];
                    if (scansize == w) {
                        System.arraycopy(pixels, off, data, 0, data.length);
                    } else {
                        for (int row = 0; row < h; row++) {
                            System.arraycopy(pixels, off + row * scansize, data, row * w, w);
                        }
                    }
                }
                bi.getRaster().setDataElements(x, y, w, h, data);
            }

            /** Wakes the waiting thread with {@code e}, then rethrows it to abort production. */
            private void fail(RuntimeException e) {
                future.completeExceptionally(e);
                throw e;
            }
        });

        return future.get();
    }

    /**
     * Creates a large sample image (random filled circles) and encodes it as a progressive JPEG.
     *
     * @param squareSize width and height of the image in pixels
     * @return the byte representation of the JPG image.
     */
    private static byte[] createImageData(int squareSize) throws Exception {
        BufferedImage bi = new BufferedImage(squareSize, squareSize,
                BufferedImage.TYPE_INT_RGB);
        Random r = new Random(0); // fixed seed: identical image content on every run
        Graphics2D g = bi.createGraphics();
        for (int a = 0; a < 20000; a++) {
            g.setColor(new Color(r.nextInt(0xffffff)));
            int radius = 10 + r.nextInt(90);
            g.fillOval(r.nextInt(bi.getWidth()), r.nextInt(bi.getHeight()),
                    radius, radius);
        }
        g.dispose();

        ImageWriter writer = ImageIO.getImageWritersByFormatName("jpg").next();
        ImageWriteParam param = writer.getDefaultWriteParam();

        // MODE_EXPLICIT is required before a compression quality may be set.
        param.setCompressionMode(ImageWriteParam.MODE_EXPLICIT);
        param.setCompressionQuality(0.85f); // Adjust quality as needed

        // Enable progressive encoding (scans), letting the writer pick its default scan progression.
        param.setProgressiveMode(ImageWriteParam.MODE_DEFAULT);

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        // The ImageOutputStream must be closed before reading `out`: closing flushes any data
        // still held in the stream's cache and releases the cache itself.
        try (ImageOutputStream ios = ImageIO.createImageOutputStream(out)) {
            if (ios == null) {
                throw new IllegalStateException("no ImageOutputStream available for a ByteArrayOutputStream");
            }
            writer.setOutput(ios);
            writer.write(null, new IIOImage(bi, null, null), param);
        } finally {
            writer.dispose();
        }
        return out.toByteArray();
    }
}