import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;

import jdk.jfr.Recording;
import jdk.jfr.consumer.RecordedClass;
import jdk.jfr.consumer.RecordedEvent;
import jdk.jfr.consumer.RecordedMethod;
import jdk.jfr.consumer.RecordedObject;
import jdk.jfr.consumer.RecordingFile;

/**
 * Small demo of hunting a memory leak with JDK Flight Recorder.
 *
 * <p>Five {@link AllocatorThread}s allocate objects in a tight loop and retain a configurable
 * percentage of the {@link LeakObject}s they create in a static map, so those objects can never be
 * collected. While the threads run, a JFR recording with {@code jdk.OldObjectSample} (with stack
 * traces) is active; afterwards the recording is dumped to {@code memory-leak.jfr} and every sampled
 * object is printed with its class and the top frame of its recorded stack trace.
 *
 * <p>Usage: {@code java MemoryLeak [<seconds> <frequency>]}, where {@code seconds} is how long the
 * allocators run and {@code frequency} is the chance, in percent, that an allocated
 * {@code LeakObject} is retained.
 */
public class MemoryLeak {
    /** Allocation that is never retained. */
    static class OtherObject {
    }

    /** The type that is (sometimes) retained in {@link #leaking} to simulate a leak. */
    static class LeakObject {
    }

    /** Number of allocator threads started by {@link #main}. */
    private static final int THREAD_COUNT = 5;

    private static final Random random = new Random();
    // Every LeakObject stored here stays reachable for the lifetime of the JVM: this is the leak.
    private static final Map<LeakObject, LeakObject> leaking = new ConcurrentHashMap<>();

    /**
     * Thread that keeps allocating random objects until {@link #terminate()} is called. For each
     * allocation a {@code frequency}-in-100 draw decides whether a {@link LeakObject} is retained.
     */
    public static class AllocatorThread extends Thread {
        private final int frequency;
        private volatile boolean alive = true;
        // Publicly visible sink for the latest allocation; presumably there so the JIT cannot
        // treat the allocation as dead and eliminate it.
        public Object object;

        /**
         * @param name      thread name
         * @param frequency chance, in percent, that an allocated {@link LeakObject} is leaked
         */
        public AllocatorThread(String name, int frequency) {
            super(name);
            this.frequency = frequency;
        }

        /** Asks the thread to stop; it exits after finishing its current iteration. */
        public void terminate() {
            alive = false;
        }

        @Override
        public void run() {
            while (alive) {
                Object allocated = allocateObject(random.nextInt(5));
                object = allocated;
                // The draw happens for every allocation, but only LeakObjects can be retained.
                if (random.nextInt(100) < frequency && allocated instanceof LeakObject leak) {
                    leaking.put(leak, leak);
                }
            }
        }

        /** Allocates one of five object kinds selected by {@code value} (0..4). */
        private Object allocateObject(int value) {
            return switch (value) {
            case 0 -> new String("Hello");
            case 1 -> new HashMap<>();
            case 2 -> new ArrayList<>();
            case 3 -> new LeakObject();
            case 4 -> new OtherObject();
            default -> throw new IllegalArgumentException("Unexpected value: " + value);
            };
        }
    }

    /**
     * Runs the demo.
     *
     * @param args optional {@code <seconds>} (default 100) and {@code <frequency>} (default 1)
     * @throws NumberFormatException if an argument is not an unsigned integer
     * @throws Exception             if recording, dumping or reading the JFR file fails
     */
    public static void main(String... args) throws Exception {
        int seconds = 100;
        int frequency = 1;
        if (args.length == 0) {
            System.out.println("Usage: java MemoryLeak [<seconds> <frequency>]");
            System.out.println("");
            System.out.println("Using default, seconds=100 and frequency=1 out of 100 allocations are leaking");
            System.out.println("");
        }
        if (args.length > 0) {
            seconds = Integer.parseUnsignedInt(args[0]);
        }
        if (args.length > 1) {
            frequency = Integer.parseUnsignedInt(args[1]);
        }

        try (Recording r = new Recording()) {
            // "cutoff" bounds how long JFR may search for paths to GC roots; "0 s" keeps that
            // search out of the way. Only the allocation stack traces are used in the report.
            r.enable("jdk.OldObjectSample").withStackTrace().with("cutoff", "0 s");
            r.start();
            runAllocators(seconds, frequency);
            r.stop();

            Path p = Path.of("memory-leak.jfr");
            Files.deleteIfExists(p);
            r.dump(p);
            printAllocationSites(RecordingFile.readAllEvents(p));
        }
    }

    /**
     * Starts the allocator threads, prints one status row per second for {@code seconds} seconds,
     * then stops the threads and waits until they have exited.
     */
    private static void runAllocators(int seconds, int frequency) throws InterruptedException {
        List<AllocatorThread> threads = new ArrayList<>();
        try {
            for (int i = 0; i < THREAD_COUNT; i++) {
                AllocatorThread t = new AllocatorThread("Allocator Thread " + i, frequency);
                threads.add(t);
                t.start();
            }
            System.out.println(" Time  | Leaking Objects | Total Memory");
            System.out.println("---------------------------------------");
            for (int elapsed = 0; elapsed < seconds; elapsed++) {
                Thread.sleep(1000);
                long total = Runtime.getRuntime().totalMemory() / (1024 * 1024);
                System.out.printf("%5d s      %12d %10d MB\n", elapsed + 1, leaking.size(), total);
            }
        } finally {
            // Stop the threads even if the loop above is interrupted: when started from the main
            // thread they are non-daemon and would otherwise keep the JVM alive forever.
            for (AllocatorThread t : threads) {
                t.terminate();
            }
        }
        // Make sure allocation has really stopped before the caller stops the recording.
        for (AllocatorThread t : threads) {
            t.join();
        }
    }

    /**
     * Prints the sampled objects ordered by allocation time, with times relative to the earliest
     * sample. Prints only the table header when there are no events.
     */
    private static void printAllocationSites(List<RecordedEvent> events) {
        List<RecordedEvent> sorted = new ArrayList<>(events);
        sorted.sort(Comparator.comparing(MemoryLeak::allocationTime));
        System.out.println();
        System.out.println(" Time  | Class                      | Allocation Site");
        System.out.println("-------------------------------------------------------------------");
        if (sorted.isEmpty()) {
            return;
        }
        Instant start = allocationTime(sorted.get(0));
        for (RecordedEvent e : sorted) {
            Duration d = Duration.between(start, allocationTime(e));
            double s = d.toNanos() / 1_000_000_000.0;
            RecordedObject leakObject = e.getValue("object");
            String className = compactClass(leakObject.getClass("type"));
            System.out.printf("%.3f   %-25s    %s\n", s, className, allocationSite(e));
        }
    }

    /** Reads the {@code allocationTime} field of an old-object sample. */
    private static Instant allocationTime(RecordedEvent e) {
        return e.getInstant("allocationTime");
    }

    /**
     * Formats the top frame of the event's stack trace as {@code Class::method(...)}, or returns a
     * placeholder when the event carries no stack trace or the trace has no frames.
     */
    private static String allocationSite(RecordedEvent e) {
        var stackTrace = e.getStackTrace();
        if (stackTrace == null || stackTrace.getFrames().isEmpty()) {
            return "<no stack trace>";
        }
        RecordedMethod m = stackTrace.getFrames().get(0).getMethod();
        return compactClass(m.getType()) + "::" + m.getName() + "(...)";
    }

    /** Strips the package and any enclosing-class prefix ({@code Outer$}) from a class name. */
    private static String compactClass(RecordedClass type) {
        String className = type.getName();
        className = className.substring(className.lastIndexOf(".") + 1);
        className = className.substring(className.lastIndexOf("$") + 1);
        return className;
    }
}