import java.nio.charset.StandardCharsets;
import java.time.Duration;

/**
 * Micro-benchmark that times repeated {@link Thread#setName(String)} calls on the current thread.
 *
 * <p>Test names are grouped into sets: names of at most 15 bytes and names longer than 15 bytes.
 * Run with no argument to benchmark every set, or with exactly one set flag ({@code short} or
 * {@code long}, matched case-insensitively) to benchmark only that set. Any other arguments print
 * a usage line and exit with status 1.
 */
public class SetThreadBench {
    /** How many timed measurements are taken (and printed) per benchmark set. */
    private static final int NUM_BENCHMARK_LOOPS = 10;
    /** Passes over a set's names within a single timed measurement. */
    private static final int NUM_TIMED_ITERATIONS = 100000;

    // Short names (<= 15 bytes; ASCII only, so bytes == characters)
    private static final String[] SHORT_TEST_NAMES = {
            "",
            "a",
            "short",
            "shortname",
            "123456789012345", // 15 chars - short
    };

    // Long names (> 15 bytes; ASCII only, so bytes == characters)
    private static final String[] LONG_TEST_NAMES = {
            "1234567890123456",
            "MutationStage-199",
            "RequestResponseStage-196",
            "Messaging-EventLoop-3-20",
            "Native-Transport-Requests-213"
    };

    /**
     * A named group of thread names to benchmark together.
     *
     * @param flagName command-line flag that selects this set
     * @param label heading printed before the set's results
     * @param names thread names assigned, in order, during each timed pass
     */
    record BenchmarkSet(String flagName, String label, String[] names) {
    }

    private static final BenchmarkSet[] BENCHMARK_SETS = {
            new BenchmarkSet("short", "SHORT NAMES <= 15 bytes", SHORT_TEST_NAMES),
            new BenchmarkSet("long", "LONG NAMES > 15 bytes", LONG_TEST_NAMES),
    };

    /** Prints a usage line listing every set flag to stderr and exits the JVM with status 1. */
    private static void exitWithUsage() {
        var allFlags = new StringBuilder();
        for (BenchmarkSet set : BENCHMARK_SETS) {
            if (allFlags.length() > 0) {
                allFlags.append("|");
            }
            allFlags.append(set.flagName());
        }
        var usageMsg = String.format("Usage: java SetThreadBench [%s]", allFlags.toString());

        System.err.println(usageMsg);
        System.exit(1);
    }

    /**
     * Runs the selected benchmark sets and prints one timing line per measurement.
     *
     * @param args empty to run every set, or a single set flag (case-insensitive)
     */
    public static void main(String[] args) {
        checkTestData();

        BenchmarkSet[] setsToRun = selectSets(args);
        if (setsToRun == null) {
            exitWithUsage();
            return;
        }

        for (BenchmarkSet set : setsToRun) {
            System.out.println("\n### " + set.label());
            for (int i = 0; i < NUM_BENCHMARK_LOOPS; i++) {
                var duration = benchmark(set.names());
                var msg = String.format("  iteration %d duration: %s", i, friendlyDuration(duration));
                System.out.println(msg);
            }
        }
    }

    /**
     * Verifies each test name is in its set's size class and is pure ASCII (byte length equals
     * character length). These are {@code assert}s, so they only run when assertions are enabled.
     */
    private static void checkTestData() {
        for (String name : SHORT_TEST_NAMES) {
            assert name.length() <= 15 : name;
            assert name.getBytes(StandardCharsets.UTF_8).length == name.length() : name;
        }
        for (String name : LONG_TEST_NAMES) {
            assert name.length() > 15 : name;
            assert name.getBytes(StandardCharsets.UTF_8).length == name.length() : name;
        }
    }

    /**
     * Resolves command-line arguments to the benchmark sets that should run.
     *
     * @param args command-line arguments: empty, or a single set flag
     * @return a fresh array holding every set when {@code args} is empty, a one-element array with
     *     the matching set when a single known flag is given, or {@code null} when the arguments are
     *     invalid (more than one argument, or an unknown flag)
     */
    static BenchmarkSet[] selectSets(String[] args) {
        if (args.length == 0) {
            // Copy so callers cannot mutate the shared table.
            return BENCHMARK_SETS.clone();
        }
        if (args.length > 1) {
            return null;
        }
        for (BenchmarkSet set : BENCHMARK_SETS) {
            // equalsIgnoreCase folds case per character and, unlike toLowerCase(), does not
            // depend on the JVM's default locale.
            if (set.flagName().equalsIgnoreCase(args[0])) {
                return new BenchmarkSet[] { set };
            }
        }
        return null;
    }

    /**
     * Formats a duration as fractional milliseconds with one decimal place, e.g. {@code "1.5 ms"}
     * (the decimal separator follows the default locale).
     */
    static String friendlyDuration(Duration duration) {
        double millis = duration.toNanos() / 1_000_000.0;
        return String.format("%.1f ms", millis);
    }

    /**
     * Times {@link #NUM_TIMED_ITERATIONS} passes of renaming the current thread to each of
     * {@code names}.
     *
     * @param names thread names to assign, in order, on every pass
     * @return elapsed wall-clock time of the timed loop
     */
    static Duration benchmark(String[] names) {
        return benchmark(names, NUM_TIMED_ITERATIONS);
    }

    /**
     * Times {@code iterations} passes of renaming the current thread to each of {@code names}, then
     * restores the thread's original name (outside the timed region).
     *
     * @param names thread names to assign, in order, on every pass
     * @param iterations number of passes over {@code names}
     * @return elapsed wall-clock time of the timed loop
     */
    static Duration benchmark(String[] names, int iterations) {
        Thread self = Thread.currentThread();
        String originalName = self.getName();
        try {
            long startTime = System.nanoTime();

            for (int i = 0; i < iterations; i++) {
                for (String name : names) {
                    Thread.currentThread().setName(name);
                }
            }

            long endTime = System.nanoTime();
            return Duration.ofNanos(endTime - startTime);
        } finally {
            // Without this the calling thread would keep the last test name after the benchmark.
            self.setName(originalName);
        }
    }
}
