import javax.net.ssl.*;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;

/**
 * Command-line probe that repeatedly opens a TLS 1.3-only connection to a server and reports the
 * outcome of every handshake plus an overall failure rate.
 *
 * <p>Usage: {@code java Tls13SocketProbe <host> [port=443] [iterations=100]}
 */
public class Tls13SocketProbe {

    private static final String USAGE =
            "Usage: java Tls13SocketProbe <host> [port=443] [iterations=100]";

    private static final int DEFAULT_PORT = 443;
    private static final int DEFAULT_ITERATIONS = 100;

    /** TCP connect timeout per attempt, in milliseconds. */
    private static final int CONNECT_TIMEOUT_MS = 10_000;

    /**
     * Socket read timeout per attempt, in milliseconds. Bounds the handshake so a server that
     * accepts the TCP connection but never answers cannot stall the probe forever.
     */
    private static final int READ_TIMEOUT_MS = 10_000;

    /** Pause between attempts to avoid hammering the server; set to 0 for a true stress run. */
    private static final long PAUSE_MS = 50;

    /** Set to {@code true} to also print the subject DN of the server's leaf certificate. */
    private static final boolean PRINT_PEER_SUBJECT = false;

    /**
     * Parses the command line, runs the configured number of probe attempts and prints a summary.
     *
     * @param args {@code <host> [port] [iterations]}
     * @throws Exception if the TLS context cannot be created or initialised
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1 || args.length > 3) {
            System.err.println(USAGE);
            System.exit(1);
        }

        final String host = args[0];
        final int port;
        final int iterations;
        try {
            port = (args.length >= 2) ? parseBoundedInt(args[1], "port", 1, 65_535) : DEFAULT_PORT;
            iterations = (args.length >= 3)
                    ? parseBoundedInt(args[2], "iterations", 0, Integer.MAX_VALUE)
                    : DEFAULT_ITERATIONS;
        } catch (IllegalArgumentException e) {
            System.err.println(e.getMessage());
            System.err.println(USAGE);
            System.exit(1);
            return; // not reached at runtime; satisfies definite assignment of port/iterations
        }

        // Null arguments select the default key and trust managers; load a custom truststore here
        // if needed.
        SSLContext ctx = SSLContext.getInstance("TLS");
        ctx.init(null, null, null);
        SSLSocketFactory factory = ctx.getSocketFactory();

        int successes = 0;
        int failures = 0;
        for (int i = 1; i <= iterations; i++) {
            try {
                probeOnce(factory, host, port, i);
                successes++;
            } catch (Exception e) {
                // Per-attempt boundary: record the failure and keep probing.
                failures++;
                System.out.printf("#%d FAIL %s%n", i, describe(e));
            }

            if (i < iterations && !pause()) {
                System.out.printf("Interrupted after %d attempt(s); stopping early.%n", i);
                break;
            }
        }

        System.out.printf("%nRun complete: successes=%d failures=%d failureRate=%.2f%%%n",
                successes, failures, failureRatePercent(successes, failures));
    }

    /**
     * Opens one connection, completes a TLS 1.3 handshake and prints a SUCCESS line with the
     * negotiated protocol, cipher suite and elapsed time (TCP connect + handshake).
     *
     * @throws IOException if connecting, the handshake, or certificate/hostname validation fails
     */
    private static void probeOnce(SSLSocketFactory factory, String host, int port, int attempt)
            throws IOException {
        final Instant start = Instant.now();
        try (SSLSocket socket = (SSLSocket) factory.createSocket()) {
            socket.connect(new InetSocketAddress(host, port), CONNECT_TIMEOUT_MS);
            // Without a read timeout, startHandshake() can block indefinitely on a silent peer.
            socket.setSoTimeout(READ_TIMEOUT_MS);
            socket.setSSLParameters(tls13Parameters(socket.getSSLParameters(), host));
            socket.startHandshake();

            final SSLSession session = socket.getSession();
            final Duration elapsed = Duration.between(start, Instant.now());
            System.out.printf(
                    "#%d SUCCESS  host=%s port=%d  protocol=%s cipher=%s  time=%dms%n",
                    attempt, host, port, session.getProtocol(), session.getCipherSuite(),
                    elapsed.toMillis());

            if (PRINT_PEER_SUBJECT) {
                System.out.println("Peer cert[0]: " + leafSubject(session));
            }
        }
    }

    /**
     * Restricts {@code params} to TLS 1.3, enables HTTPS-style hostname verification and, for DNS
     * names, sets the SNI server name. Package-private for testing.
     *
     * @param params parameters to modify in place
     * @param host   the host name or IP literal being probed
     * @return the same {@code params} instance, after modification
     */
    static SSLParameters tls13Parameters(SSLParameters params, String host) {
        params.setProtocols(new String[] {"TLSv1.3"});
        // Optionally restrict cipher suites to the common TLS 1.3 ones:
        // params.setCipherSuites(new String[] {
        //     "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256"});

        // SSLSocket performs no hostname check by default; without this, any trusted certificate
        // would be accepted regardless of the name it was issued for.
        params.setEndpointIdentificationAlgorithm("HTTPS");

        // SNI lets multi-tenant frontends pick the right certificate. RFC 6066 forbids IP literals
        // in SNI (and SNIHostName rejects IPv6 literals), so only set it for DNS names.
        if (!isIpLiteral(host)) {
            params.setServerNames(Arrays.asList(new SNIHostName(host)));
        }
        return params;
    }

    /**
     * Heuristically reports whether {@code host} is an IP literal: anything containing ':' (IPv6),
     * or a non-empty string made only of ASCII digits and dots (IPv4). Package-private for testing.
     */
    static boolean isIpLiteral(String host) {
        if (host.indexOf(':') >= 0) {
            return true;
        }
        return !host.isEmpty() && host.chars().allMatch(c -> c == '.' || (c >= '0' && c <= '9'));
    }

    /** Returns the subject DN of the server's leaf certificate, or a placeholder if there is none. */
    private static String leafSubject(SSLSession session) throws SSLPeerUnverifiedException {
        final Certificate[] chain = session.getPeerCertificates();
        if (chain == null || chain.length == 0) {
            return "<no peer cert>";
        }
        final Certificate leaf = chain[0];
        return (leaf instanceof X509Certificate)
                ? ((X509Certificate) leaf).getSubjectX500Principal().getName()
                : leaf.toString();
    }

    /**
     * Formats a failure as {@code "SimpleClassName: message"}, substituting a placeholder when the
     * exception has no message (so the output never shows a bare "null"). Package-private for testing.
     */
    static String describe(Throwable t) {
        final String message = t.getMessage();
        return t.getClass().getSimpleName() + ": " + (message != null ? message : "<no message>");
    }

    /**
     * Returns failures as a percentage of all attempts, or 0 when there were no attempts.
     * Package-private for testing.
     */
    static double failureRatePercent(int successes, int failures) {
        final int total = successes + failures;
        return (total == 0) ? 0.0 : (failures * 100.0) / total;
    }

    /**
     * Parses {@code text} as a decimal int and checks it lies in {@code [min, max]}.
     * Package-private for testing.
     *
     * @throws IllegalArgumentException if {@code text} is not an integer or is out of range
     */
    static int parseBoundedInt(String text, String name, int min, int max) {
        final int value;
        try {
            value = Integer.parseInt(text);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    "Invalid " + name + ": '" + text + "' is not an integer", e);
        }
        if (value < min || value > max) {
            throw new IllegalArgumentException(
                    "Invalid " + name + ": " + value + " (expected " + min + ".." + max + ")");
        }
        return value;
    }

    /**
     * Sleeps for {@link #PAUSE_MS}. If the thread is interrupted, restores the interrupt flag and
     * returns {@code false} so the caller can stop early; otherwise returns {@code true}.
     */
    private static boolean pause() {
        try {
            Thread.sleep(PAUSE_MS);
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }
}

