package com.java.tlsfun.tests.server.jsse.poc;

import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;

/**
 * This PoC is for figuring out how a TLS server responds to empty TLS plaintext messages.
 *
 * RFC 5246 says the following about empty TLS plaintexts (section 6.2.1):
 *
 *      Implementations MUST NOT send zero-length fragments of Handshake, Alert, or
 *      ChangeCipherSpec content types. Zero-length fragments of Application data MAY
 *      be sent as they are potentially useful as a traffic analysis countermeasure.
 *
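 * So, an empty Handshake plaintext (which is what this PoC sends) is forbidden and should be
 * rejected, whereas an empty Application data plaintext would be allowed.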
 *
 * If no parameters are provided to this PoC, it starts a local TLS server
 * and sends empty TLS plaintexts to it.
 *
 * If a hostname and port are provided, this PoC sends empty TLS plaintexts to the specified address.
 */
public class EmptyTLSPlaintext {

    private static final int SERVER_TIMEOUT = 5 * 1000; // 5 seconds
    private static final int CLIENT_DELAY = 3 * 1000; // 3 seconds

    // Fields of a TLS record header that together announce a Handshake
    // record whose fragment is empty (zero bytes long).
    private static final byte HANDSHAKE_MESSAGE = 22;   // content type: handshake
    private static final byte[] TLS12_VERSION = {3, 3}; // protocol version: TLS 1.2
    private static final byte[] ZERO_LENGTH = {0, 0};   // 16-bit fragment length: 0

    /*
     * Command line options:
     *      args[0] - hostname  (optional)
     *      args[1] - port      (optional)
     *
     * Both must be given to target a remote server. With fewer than two
     * arguments (a lone hostname is ignored), a local TLS server is started
     * and targeted instead.
     */
    public static void main(String[] args) throws Exception {
        if (args.length > 1) {
            String host = args[0];
            int port = Integer.parseInt(args[1]);
            runClient(host, port);
        } else {
            try (SimpleServer server = start()) {
                sleep(CLIENT_DELAY); // wait before connecting to the local server
                runClient("localhost", server.getPort());
            }
        }
    }

    /**
     * Connects to {@code host:port} over plain TCP and repeatedly sends a record
     * header announcing a zero-length Handshake fragment, pausing
     * {@code SERVER_TIMEOUT / 2} milliseconds between sends.
     *
     * <p>The loop has no exit condition of its own: it only ends when a write fails
     * (for example once the peer has closed the connection), in which case the
     * {@link IOException} propagates to the caller.
     *
     * @param host the server host name or address
     * @param port the server port
     * @throws IOException if connecting or writing fails
     */
    private static void runClient(String host, int port) throws IOException {
        System.out.printf("connect to %s:%d%n", host, port);
        byte[] record = emptyHandshakeRecord();
        try (Socket socket = new Socket(host, port)) {
            OutputStream out = new BufferedOutputStream(socket.getOutputStream());
            while (true) {
                System.out.println("send an empty TLS plaintext");
                // the whole 5-byte header goes out in one flush
                out.write(record);
                out.flush();
                sleep(SERVER_TIMEOUT / 2);
            }
        }
    }

    /**
     * Builds the 5-byte record header sent by the client: content type Handshake (22),
     * version TLS 1.2 ({3, 3}) and a fragment length of 0. No fragment bytes follow,
     * so the record's plaintext is empty.
     *
     * @return a freshly allocated array; callers may modify it
     */
    static byte[] emptyHandshakeRecord() {
        byte[] record = new byte[1 + TLS12_VERSION.length + ZERO_LENGTH.length];
        record[0] = HANDSHAKE_MESSAGE;
        System.arraycopy(TLS12_VERSION, 0, record, 1, TLS12_VERSION.length);
        System.arraycopy(ZERO_LENGTH, 0, record, 1 + TLS12_VERSION.length,
                ZERO_LENGTH.length);
        return record;
    }

    /**
     * Sleeps for {@code delay} milliseconds.
     *
     * @param delay sleep time in milliseconds
     * @throws RuntimeException wrapping the {@link InterruptedException} if the thread
     *         is interrupted; the interrupt status is restored before throwing
     */
    private static void sleep(long delay) {
        try {
            Thread.sleep(delay);
        } catch (InterruptedException e) {
            // Thread.sleep clears the interrupt flag when it throws; restore it so
            // code further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a TLS server socket on an ephemeral port via the default
     * {@link SSLServerSocketFactory} and runs a {@link SimpleServer} on it in a new
     * thread named "TLS server".
     *
     * @return the running server; closing it closes the listening socket
     * @throws IOException if the server socket cannot be created or configured
     */
    private static SimpleServer start() throws IOException {
        SSLServerSocketFactory ssf = (SSLServerSocketFactory)
                SSLServerSocketFactory.getDefault();
        SSLServerSocket ssocket = (SSLServerSocket) ssf.createServerSocket(0);
        try {
            ssocket.setSoTimeout(SERVER_TIMEOUT);
        } catch (IOException e) {
            // don't leak the bound socket if it cannot be configured
            try {
                ssocket.close();
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }

        SimpleServer server = new SimpleServer(ssocket);
        new Thread(server, "TLS server").start();
        return server;
    }

    /**
     * Minimal single-connection TLS server built on {@link SSLServerSocket}.
     *
     * <p>{@link #run()} accepts one connection, reads one byte, echoes it back and
     * returns. {@link #close()} closes the listening socket.
     */
    private static class SimpleServer implements Runnable, AutoCloseable {

        private final SSLServerSocket ssocket;

        // true while run() is executing (set after the "started" log line)
        private volatile boolean running = false;

        SimpleServer(SSLServerSocket ssocket) {
            this.ssocket = ssocket;
        }

        /**
         * Accepts a single connection and echoes back the first byte read from it.
         * Any failure (for example an accept or read timeout, or a TLS error) is
         * printed to stdout rather than rethrown.
         */
        @Override
        public void run() {
            System.out.println("started on port: " + getPort());
            running = true;
            try (SSLSocket socket = (SSLSocket) ssocket.accept()) {
                System.out.println("accepted");
                socket.setSoTimeout(SERVER_TIMEOUT);
                InputStream in = socket.getInputStream();
                OutputStream out = socket.getOutputStream();
                // Per the SSLSocket contract the first read starts the handshake,
                // so a malformed record from the client surfaces here.
                int b = in.read();
                if (b == -1) {
                    // End of stream: nothing to echo. Writing -1 would send 0xFF.
                    System.out.println("connection closed by peer");
                } else {
                    System.out.println("send data: " + b);
                    out.write(b);
                    out.flush();
                }
                socket.getSession().invalidate();
            } catch (Throwable e) {
                System.out.println("exception: " + e);
                e.printStackTrace(System.out);
            }
            System.out.println("finished");
            running = false;
        }

        /** Returns the local port the listening socket is bound to. */
        public int getPort() {
            return ssocket.getLocalPort();
        }

        /** Returns whether {@link #run()} is currently executing. */
        public boolean isRunning() {
            return running;
        }

        /**
         * Closes the listening socket if it is still open. A failure to close is
         * printed, not thrown. An already accepted connection is closed by
         * {@link #run()} itself.
         */
        void stop() {
            if (!ssocket.isClosed()) {
                try {
                    System.out.println("close socket");
                    ssocket.close();
                } catch (IOException e) {
                    System.out.println("socket closed: " + e);
                }
            }
        }

        /** Same as {@link #stop()}; allows use in try-with-resources. */
        @Override
        public void close() {
            stop();
        }
    }
}
