import jdk.jfr.Configuration;
import jdk.jfr.Recording;
import java.nio.file.Path;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;

/**
 * Stress program that runs many short-lived virtual threads while a JFR recording captures
 * CPU-time samples and virtual-thread pinning / submit-failure events.
 *
 * <p>Usage: {@code java VirtualThreadJfr [durationSeconds [numVirtualThreads]]}. The duration
 * (default 10 s) is a time budget: all tasks must finish within it plus a 5 s grace period,
 * otherwise the program exits with status 1. The recording is written to
 * {@code virtual_thread_stress.jfr} in the working directory when it is stopped.
 */
public class VirtualThreadJfr {
    /** Time budget used when no duration argument is given. */
    private static final Duration DEFAULT_DURATION = Duration.ofSeconds(10);

    /** Number of virtual threads started when no count argument is given. */
    private static final int DEFAULT_VIRTUAL_THREADS = 10_000_000;

    /** Extra time allowed beyond the requested duration before declaring a timeout. */
    private static final long GRACE_PERIOD_MILLIS = 5_000;

    /** Iterations of floating-point busy work performed by each virtual thread. */
    private static final int WORK_ITERATIONS = 5_000;

    /**
     * How long each virtual thread parks after its busy work: 10_000 ns (10 µs). NOTE(review): the
     * original literal was written {@code 100_00}; confirm 10 µs rather than 100 µs was intended.
     */
    private static final long PARK_NANOS = 10_000;

    /** File the recording is written to when it is stopped. */
    private static final Path OUTPUT_FILE = Path.of("virtual_thread_stress.jfr");

    private static final String USAGE =
            "Usage: VirtualThreadJfr [durationSeconds [numVirtualThreads]]";

    /**
     * Accumulates the busy-work results. It is never read in this file; presumably it exists so the
     * computation in {@link #busyWork()} is not optimized away.
     */
    private static final AtomicLong computations = new AtomicLong(0);

    /**
     * Parses the optional arguments, starts the JFR recording, runs the virtual-thread workload and
     * stops the recording. Exits with status 1 on invalid arguments or if the workload does not
     * finish within the duration plus the grace period.
     *
     * @param args optional {@code durationSeconds} and {@code numVirtualThreads}
     * @throws Exception if the JFR configuration cannot be loaded or the wait is interrupted
     */
    public static void main(String[] args) throws Exception {
        Duration duration = DEFAULT_DURATION;
        int numVirtualThreads = DEFAULT_VIRTUAL_THREADS;
        try {
            if (args.length > 2) {
                throw new IllegalArgumentException(USAGE);
            }
            if (args.length >= 1) {
                duration = Duration.ofSeconds(parseNonNegativeInt(args[0], "duration"));
            }
            if (args.length == 2) {
                numVirtualThreads = parseNonNegativeInt(args[1], "virtual thread count");
            }
        } catch (IllegalArgumentException e) {
            System.err.println(e.getMessage());
            System.exit(1);
        }

        System.out.println("Starting VirtualThreadJfr for " + duration.toMillis() + " ms with "
                + numVirtualThreads + " virtual threads");

        try (Recording recording = new Recording(Configuration.getConfiguration("default"))) {
            recording.setDestination(OUTPUT_FILE);
            // Configure events before start() so nothing emitted early in the run is missed.
            enableEvents(recording);
            recording.start();

            try (ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {
                for (int i = 0; i < numVirtualThreads; i++) {
                    // execute() rather than submit(): no Future per task is needed.
                    executor.execute(VirtualThreadJfr::busyWork);
                }
                executor.shutdown();
                long timeoutMillis = duration.toMillis() + GRACE_PERIOD_MILLIS;
                if (!executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS)) {
                    System.err.println("Virtual threads did not complete in time.");
                    // stop() writes the recording to its destination. Exit here: letting the
                    // executor's try-with-resources close it would block until every task finishes.
                    recording.stop();
                    recording.close();
                    System.exit(1);
                }
            }

            recording.stop();
        }

        System.out.println("Test finished successfully.");
    }

    /**
     * Enables the CPU-time sampling and virtual-thread events targeted by this stress test, each
     * with stack traces and no duration threshold.
     */
    private static void enableEvents(Recording recording) {
        recording.enable("jdk.CPUTimeSampleLoss");
        // NOTE(review): jdk.CPUTimeSample may be rate-controlled by a "throttle" setting rather than
        // "period" on the target JDK; confirm the period setting has the intended effect.
        recording.enable("jdk.CPUTimeSample")
                .withStackTrace()
                .withoutThreshold()
                .withPeriod(Duration.ofNanos(1));
        recording.enable("jdk.VirtualThreadPinned")
                .withStackTrace()
                .withoutThreshold()
                .withPeriod(Duration.ofNanos(1));
        recording.enable("jdk.VirtualThreadSubmitFailed")
                .withStackTrace()
                .withoutThreshold()
                .withPeriod(Duration.ofNanos(1));
    }

    /** Per-task workload: a short floating-point loop followed by a brief park. */
    private static void busyWork() {
        double result = 0;
        for (int j = 0; j < WORK_ITERATIONS; j++) {
            result += Math.sin(j) * Math.cos(j);
        }
        computations.addAndGet((long) result);
        LockSupport.parkNanos(PARK_NANOS);
    }

    /**
     * Parses a command-line argument as a non-negative {@code int}.
     *
     * @param value the raw argument
     * @param name human-readable argument name used in the error message
     * @return the parsed value
     * @throws IllegalArgumentException if {@code value} is not an integer or is negative
     */
    private static int parseNonNegativeInt(String value, String name) {
        int parsed;
        try {
            parsed = Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid " + name + " argument: " + value, e);
        }
        if (parsed < 0) {
            throw new IllegalArgumentException("Invalid " + name + " argument: " + value);
        }
        return parsed;
    }
}
