import javax.net.ssl.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.GeneralSecurityException;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Random;
import java.util.concurrent.*;

public class Main {

    /** Keystore used when no path is passed as {@code args[0]} (the original hard-coded location). */
    static final String DEFAULT_KEYSTORE_PATH = "C:\\Users\\TONGWAN\\Documents\\JI-9074360\\server.p12";

    /** Password of the PKCS#12 keystore. */
    private static final String KEYSTORE_PASSWORD = "password";

    /** Local port the test server listens on and the test client connects to. */
    private static final int PORT = 12345;

    /**
     * Reproduces closing a TLS server socket while a write to a slow reader is in progress.
     *
     * <p>The client reads a single byte and then stalls. Meanwhile the server writes a large
     * payload on a pool thread. After five seconds the main thread closes the server-side
     * connection from a thread other than the one blocked in {@code write}.
     *
     * @param args optional; {@code args[0]} overrides the PKCS#12 keystore path
     * @throws Exception if the keystore or SSL context cannot be set up, or the main thread is
     *     interrupted while waiting
     */
    public static void main(String[] args) throws Exception {
        // Any setup failure (including KeyManagementException, which used to be swallowed and
        // leave an uninitialized context behind) now propagates and stops the program here.
        SSLContext sslContext = createSslContext(resolveKeystorePath(args), KEYSTORE_PASSWORD.toCharArray());

        // Start the client first; it retries until the server is listening.
        TestClient client = new TestClient(sslContext);
        Thread clientThread = new Thread(client);
        clientThread.start();

        TestServer server = new TestServer(sslContext);
        Thread serverThread = new Thread(server);
        serverThread.start();

        Thread.sleep(5000);
        System.out.println("Slow client, better kick them off...");
        server.closeConnection();

        System.out.println("Connection should be successfully closed now.");
    }

    /**
     * Picks the keystore path from the command line, falling back to the default location.
     *
     * @param args program arguments; may be {@code null} or empty
     * @return {@code args[0]} when present, otherwise {@link #DEFAULT_KEYSTORE_PATH}
     */
    static String resolveKeystorePath(String[] args) {
        if (args != null && args.length > 0) {
            return args[0];
        }
        return DEFAULT_KEYSTORE_PATH;
    }

    /**
     * Builds a TLS context that presents the keystore's key and trusts every peer certificate.
     *
     * @param keystorePath path to a PKCS#12 keystore
     * @param password password for both the keystore and its key entries
     * @return an initialized {@code SSLContext} for protocol "tls"
     * @throws IOException if the keystore file cannot be read
     * @throws GeneralSecurityException if the keystore, key manager or context cannot be set up
     */
    static SSLContext createSslContext(String keystorePath, char[] password)
            throws GeneralSecurityException, IOException {
        KeyStore keystore = KeyStore.getInstance("pkcs12");
        FileInputStream in = new FileInputStream(new File(keystorePath));
        try {
            keystore.load(in, password);
        } finally {
            // Previously the stream was never closed.
            in.close();
        }

        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(keystore, password);

        SSLContext sslContext = SSLContext.getInstance("tls");
        sslContext.init(kmf.getKeyManagers(), new TrustManager[] { createTrustAllManager() }, null);
        return sslContext;
    }

    /**
     * Returns a trust manager that accepts every certificate chain.
     *
     * <p>This is a shortcut to get a working keystore/truststore locally. It disables certificate
     * validation entirely and must never be used outside this local reproduction.
     *
     * @return a trust manager whose checks never throw and which advertises no issuers
     */
    static X509TrustManager createTrustAllManager() {
        return new X509TrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                // Intentionally accept everything (local test only).
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                // Intentionally accept everything (local test only).
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }
        };
    }

    /**
     * Accepts one TLS connection and writes a large random payload to it on a pool thread.
     * It waits up to 20 seconds for the write to finish.
     */
    private static class TestServer implements Runnable {

        private final SSLContext sslContext;

        /** Accepted connection. Assigned on the server thread and closed from the main thread. */
        private volatile SSLSocket outConn;

        public TestServer(SSLContext ctx) {
            this.sslContext = ctx;
        }

        /**
         * Closes the accepted connection, if one has been accepted yet.
         *
         * @throws IOException if closing the socket fails
         */
        public void closeConnection() throws IOException {
            SSLSocket conn = outConn;
            if (conn == null) {
                // Previously this was a NullPointerException.
                System.out.println("No connection has been accepted yet; nothing to close.");
                return;
            }
            conn.close();
        }

        @Override
        public void run() {
            SSLServerSocket serverSock = null;
            ExecutorService executor = null;
            try {
                serverSock = (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(PORT);
                serverSock.setNeedClientAuth(false);
                serverSock.setWantClientAuth(false);

                // Accept the connection.
                outConn = (SSLSocket) serverSock.accept();
                // Uncomment this line for it to succeed.
                // outConn.setSoLinger(true, 0);
                System.out.println("Connection SO_LINGER: " + outConn.getSoLinger());

                // Write some random data to the socket on a separate thread, so this thread can
                // time out the wait.
                executor = Executors.newFixedThreadPool(1);
                final SSLSocket conn = outConn;
                Future<Void> serverFuture = executor.submit(new Callable<Void>() {
                    @Override
                    public Void call() throws Exception {
                        byte[] payload = new byte[3000000];
                        new Random().nextBytes(payload);
                        conn.getOutputStream().write(payload);
                        return null;
                    }
                });

                try {
                    serverFuture.get(20000, TimeUnit.MILLISECONDS);
                    System.out.println("Finished writing data.");
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    e.printStackTrace();
                } catch (ExecutionException e) {
                    e.printStackTrace();
                } catch (TimeoutException e) {
                    System.out.println("Timed out trying to write data.");
                }
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                if (executor != null) {
                    // shutdown() (not shutdownNow) leaves an in-flight write undisturbed.
                    // It lets the pool thread exit once the write returns, so the thread no
                    // longer keeps the JVM alive forever.
                    executor.shutdown();
                }
                if (serverSock != null) {
                    try {
                        serverSock.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    /**
     * Connects to the test server, reads a single byte, then stalls for about 30 seconds without
     * reading, so that the server's write backs up. Connection attempts are retried until one
     * succeeds; only that one session is run.
     */
    private static class TestClient implements Runnable {

        /** Pause between failed connection attempts, to avoid a busy retry loop. */
        private static final long RETRY_DELAY_MS = 100;

        private final SSLContext sslContext;

        private volatile boolean running = true;

        public TestClient(SSLContext sslContext) {
            this.sslContext = sslContext;
        }

        @Override
        public void run() {
            while (running) {
                SSLSocket sslSock = null;
                try {
                    sslSock = (SSLSocket) sslContext.getSocketFactory().createSocket();
                    sslSock.setUseClientMode(true);
                    sslSock.connect(new InetSocketAddress("localhost", PORT), 2000);
                    // The server accepts a single connection, so stop retrying once connected.
                    // Previously the loop never ended and spun on "connection refused".
                    running = false;

                    // Read very slowly off the socket.
                    sslSock.getInputStream().read();
                    stall(sslSock);
                } catch (IOException e) {
                    System.out.println("Client exception: " + e.getMessage());
                } finally {
                    if (sslSock != null) {
                        try {
                            sslSock.close();
                        } catch (IOException ignored) {
                            // Best-effort close; nothing useful to do on failure.
                        }
                    }
                }

                if (running) {
                    // The server is not listening yet; back off briefly before retrying.
                    try {
                        Thread.sleep(RETRY_DELAY_MS);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        running = false;
                    }
                }
            }
        }

        /** Sleeps in one-second steps for up to 30 seconds without reading from the socket. */
        private void stall(SSLSocket sslSock) {
            int seconds = 0;
            // Socket.isConnected() reports whether the socket was ever connected. In practice
            // this loop is therefore a bounded ~30 s stall.
            while (sslSock.isConnected() && seconds < 30) {
                try {
                    Thread.sleep(1000);
                    seconds++;
                } catch (InterruptedException e) {
                    // Restore the flag and stop stalling; re-sleeping would throw immediately.
                    Thread.currentThread().interrupt();
                    e.printStackTrace();
                    break;
                }
            }
            if (seconds >= 30) {
                System.out.println("Client thread timing out now.");
            }
        }
    }
}
