/*
 * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.FileInputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.security.KeyStore;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509KeyManager;


/**
 * Stand-alone test of ALPN negotiation over an {@link SSLSocket} pair.
 *
 * <p>A server thread offers only {@code "h2"}; the client (running on the
 * constructing thread) offers {@code "http/1.1"} and {@code "h2"} and fails
 * unless {@code "h2"} is negotiated.  Both sides use a wrapping key manager
 * ({@code MyX509KeyManager}) so that the key-selection callbacks are invoked
 * during the handshake.
 */
public class HandShakeALPN {
    static final char[] passwd = "passphrase".toCharArray();
    static final boolean debug = false;

    // NOTE(review): machine-specific absolute paths; the test can only run
    // where these files exist.
    static final String keyFilename = "/Users/javase/Documents/workspace/ALPN_test/src/etc/keystore";
    static final String trustFilename = "/Users/javase/Documents/workspace/ALPN_test/src/etc/truststore";

    // Set by the server thread once it is listening (or once it has died).
    volatile static boolean serverReady = false;
    // 0 asks for an ephemeral port; replaced by the bound port once listening.
    volatile int serverPort = 0;
    volatile Exception serverException = null;
    volatile Exception clientException = null;

    Thread serverThread = null;

    public static void main(String[] args) throws Exception {
        if (debug) {
            System.setProperty("javax.net.debug", "all");
        }

        // start the test
        new HandShakeALPN();
    }

    /**
     * Loads a JKS key store from {@code filename} using the shared test
     * password.  The file stream is always closed.
     */
    private static KeyStore loadKeyStore(String filename) throws Exception {
        KeyStore ks = KeyStore.getInstance("JKS");
        try (InputStream in = new FileInputStream(filename)) {
            ks.load(in, passwd);
        }
        return ks;
    }

    /**
     * Builds a TLS context from the test key store and trust store, with the
     * default key manager wrapped in a {@code MyX509KeyManager}.
     *
     * @return an initialized {@link SSLContext}
     * @throws Exception if the stores cannot be loaded, or the first default
     *         key manager is not an {@link X509ExtendedKeyManager}
     */
    SSLContext getSSLContext() throws Exception {
        SSLContext ctx = SSLContext.getInstance("TLS");

        KeyStore keyKS = loadKeyStore(keyFilename);
        KeyStore trustKS = loadKeyStore(trustFilename);

        KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
        kmf.init(keyKS, passwd);

        KeyManager[] kms = kmf.getKeyManagers();
        if (!(kms[0] instanceof X509ExtendedKeyManager)) {
            throw new Exception("kms[0] not X509ExtendedKeyManager");
        }

        // Replace the default key manager with the wrapping one.
        kms = new KeyManager[] {
                new MyX509KeyManager((X509KeyManager) kms[0]) };

        TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
        tmf.init(trustKS);
        TrustManager[] tms = tmf.getTrustManagers();

        ctx.init(kms, tms, null);

        return ctx;
    }

    /*
     * Define the server side of the test.
     *
     * Accepts a single connection, requires client authentication, offers
     * only "h2" via ALPN, then exchanges one byte with the client.
     */
    void doServerSide() throws Exception {
        SSLServerSocketFactory sslssf = getSSLContext()
                .getServerSocketFactory();

        // Both the listening socket and the accepted connection are closed
        // when this method exits, normally or otherwise.
        try (SSLServerSocket sslServerSocket = (SSLServerSocket) sslssf
                .createServerSocket(serverPort)) {
            sslServerSocket.setNeedClientAuth(true);

            serverPort = sslServerSocket.getLocalPort();

            /*
             * Signal Client, we're ready for his connect.
             */
            serverReady = true;

            try (SSLSocket sslSocket = (SSLSocket) sslServerSocket.accept()) {

                SSLParameters sslp = sslSocket.getSSLParameters();

                /*
                 * The default ciphersuite ordering from the SSLContext may
                 * not reflect "h2" ciphersuites as being preferred,
                 * additionally the client may not send them in an
                 * appropriate order. We could resort the suite list if so
                 * desired.
                 */
                String[] suites = sslp.getCipherSuites();
                sslp.setCipherSuites(suites);
                sslp.setUseCipherSuitesOrder(true); // Set server side order

                // Force selection.
                sslp.setApplicationProtocols(new String[] {"h2"});
                sslSocket.setSSLParameters(sslp);

                sslSocket.startHandshake();

                String ap = sslSocket.getApplicationProtocol();
                System.out.println("Application Protocol: \"" + ap + "\"");

                InputStream sslIS = sslSocket.getInputStream();
                OutputStream sslOS = sslSocket.getOutputStream();

                sslIS.read();
                sslOS.write(85);
                sslOS.flush();
            }
        }

    }

    /*
     * Define the client side of the test.
     *
     * Waits until the server thread reports ready (startServer() also sets
     * serverReady when the server dies, so this cannot wait forever), then
     * connects, offers "http/1.1" and "h2", and requires "h2" to be chosen.
     */
    void doClientSide() throws Exception {

        /*
         * Wait for server to get started.
         */
        while (!serverReady) {
            Thread.sleep(50);
        }

        SSLSocketFactory sslsf = getSSLContext().getSocketFactory();
        System.out.println("serverPort =" + serverPort);
        try (SSLSocket sslSocket = (SSLSocket) sslsf.createSocket("localhost",
                serverPort)) {

            SSLParameters sslp = sslSocket.getSSLParameters();

            /*
             * Mirror the server's parameter setup.  The suite list is passed
             * back unchanged; it could be re-sorted here if "h2"-friendly
             * suites had to be offered first.
             */
            String[] suites = sslp.getCipherSuites();
            sslp.setCipherSuites(suites);
            // NOTE(review): cipher suite order preference is a server-side
            // setting; presumably a no-op here. Kept as in the server setup.
            sslp.setUseCipherSuitesOrder(true);

            // Offer two protocols; the server only accepts "h2".
            sslp.setApplicationProtocols(new String[] { "http/1.1", "h2" });
            sslSocket.setSSLParameters(sslp);

            sslSocket.startHandshake();

            /*
             * Check that the resulting connection meets our defined ALPN
             * criteria.  If we were connecting to a non-JSSE implementation,
             * the server might have negotiated something we shouldn't accept.
             *
             * We were expecting H2 from server, let's make sure the
             * conditions match.
             */
            String ap = sslSocket.getApplicationProtocol();
            System.out.println("Application Protocol: \"" + ap + "\"");
            if (!"h2".equals(ap)) {
                throw new RuntimeException(
                        "expected ALPN value = h2, got " + ap);
            }

            InputStream sslIS = sslSocket.getInputStream();
            OutputStream sslOS = sslSocket.getOutputStream();

            // OutputStream.write(int) sends only the low-order eight bits,
            // so a single byte (0x18) goes on the wire.
            sslOS.write(280);
            sslOS.flush();
            sslIS.read();
        }

    }

    /*
     * Primary constructor, used to drive remainder of the test.
     *
     * Fork off the other side, then do your work.
     */
    HandShakeALPN() throws Exception {
        Exception startException = null;
        try {
            startServer();
            startClient();
        } catch (Exception e) {
            startException = e;
        }

        /*
         * Wait for other side to close down.
         */
        if (serverThread != null) {
            serverThread.join();
        }

        /*
         * Check various exception conditions.
         */
        Exception exception = null;

        if ((clientException != null) && (serverException != null)) {
            // If both failed, report the current (client) thread's exception
            // and attach the server's.  addSuppressed is used rather than
            // initCause because initCause throws IllegalStateException when
            // the exception already has a cause, which would mask both
            // failures.
            clientException.addSuppressed(serverException);
            exception = clientException;
        } else if (clientException != null) {
            exception = clientException;
        } else if (serverException != null) {
            exception = serverException;
        } else if (startException != null) {
            exception = startException;
        }

        if (exception != null) {
            if (exception != startException && startException != null) {
                exception.addSuppressed(startException);
            }

            throw exception;
        }

        // Fall-through: no exception to throw!
    }

    /**
     * Starts {@link #doServerSide()} on a new thread, recording any exception
     * it throws in {@code serverException}.
     */
    void startServer() throws Exception {
        serverThread = new Thread() {
            @Override
            public void run() {
                try {
                    doServerSide();
                } catch (Exception e) {
                    /*
                     * Our server thread just died.
                     *
                     * Release the client, if not active already...
                     */
                    System.err.println("Server died...");
                    serverReady = true;
                    serverException = e;
                } finally {
                    // Also release the client if the thread is ending with
                    // an Error, which the catch above does not see.
                    serverReady = true;
                }
            }
        };
        serverThread.start();
    }

    /**
     * Runs {@link #doClientSide()} on the calling thread, recording any
     * exception it throws in {@code clientException}.
     */
    void startClient() throws Exception {
        try {
            doClientSide();
        } catch (Exception e) {
            clientException = e;
        }

    }
}

/**
 * Key manager that delegates every call to a wrapped {@link X509KeyManager},
 * but first verifies, in the two socket-based alias-selection callbacks, that
 * the handshake in progress has already settled on the {@code "h2"}
 * application protocol.
 */
class MyX509KeyManager implements X509KeyManager {

    /** The key manager all calls are forwarded to. */
    final X509KeyManager km;

    MyX509KeyManager(X509KeyManager km) {
        this.km = km;
    }

    /**
     * Fails unless the handshake in progress on {@code socket} reports
     * {@code "h2"} as its application protocol.
     *
     * @param socket the socket handed to the alias-selection callback; it is
     *        cast to {@link SSLSocket} without a prior type check
     * @throws RuntimeException if the reported protocol is not {@code "h2"}
     */
    private static void requireH2(Socket socket) {
        String ap = ((SSLSocket) socket).getHandshakeApplicationProtocol();
        if (!"h2".equals(ap)) {
            throw new RuntimeException("expected ALPN value = h2, got " + ap);
        }
    }

    @Override
    public String[] getClientAliases(String keyType, Principal[] issuers) {
        return km.getClientAliases(keyType, issuers);
    }

    @Override
    public String chooseClientAlias(String[] keyType, Principal[] issuers,
            Socket socket) {
        requireH2(socket);
        return km.chooseClientAlias(keyType, issuers, socket);
    }

    @Override
    public String[] getServerAliases(String keyType, Principal[] issuers) {
        return km.getServerAliases(keyType, issuers);
    }

    @Override
    public String chooseServerAlias(String keyType, Principal[] issuers,
            Socket socket) {
        requireH2(socket);
        return km.chooseServerAlias(keyType, issuers, socket);
    }

    @Override
    public X509Certificate[] getCertificateChain(String alias) {
        return km.getCertificateChain(alias);
    }

    @Override
    public PrivateKey getPrivateKey(String alias) {
        return km.getPrivateKey(alias);
    }
}
