/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * ORACLE PROPRIETARY/CONFIDENTIAL. Use is subject to license terms.
 */

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.Socket;
import java.util.logging.Level;

/**
 * Connection handler implementing the server side of a SOCKS version 4
 * exchange: it parses the client request (version, command code, port,
 * IPv4 address, user id), replies with a granted/rejected message and, for
 * the "establish connection" command, relays data between the client and the
 * requested remote host.
 *
 * <p>Filtering decisions are delegated to {@code checkConnection} and
 * {@code checkData}, which are not defined in this class.
 */
public class Socks4ConnectionHandler extends AbstractConnectionHandler {

    /** Size, in bytes, of the chunk buffer used by {@link #readDataFromInputStream(InputStream)}. */
    private static final int BUFFER_SIZE = 512;

    /** Pause, in milliseconds, between polls when a peer has nothing to read. */
    private static final long POLL_INTERVAL_MILLIS = 500;

    /**
     * Delay (in seconds)
     */
    private int delay = 0;

    /**
     * Creates a handler for the given client socket.
     *
     * @param socket the accepted client socket, passed on to the superclass
     */
    public Socks4ConnectionHandler(Socket socket) {
        super(socket);
    }

    /**
     * Set delay (in seconds)
     *
     * <p>Half of the delay (integer division) is applied before the remote
     * connection is set up and half after the relay has finished.
     *
     * @param delay the total delay
     */
    public void setDelay(int delay) {
        this.delay = delay;
    }

    /**
     * Reads one request from the client, answers it and, when the request is
     * accepted, serves it.
     *
     * <p>A {@code Socks4Exception} raised while parsing or dispatching the
     * request is logged at {@code FINE} level and not propagated.
     *
     * @param socket the client socket
     * @throws IOException if reading the request or writing the reply fails
     */
    @Override
    protected void handleConnection(Socket socket) throws IOException {
        try {
            logger.fine("Handle client connection");
            InputStream clientInput = new BufferedInputStream(socket.getInputStream());
            OutputStream clientOutput = new BufferedOutputStream(socket.getOutputStream());
            // The data view shares the buffered stream, so bytes buffered while
            // parsing the header stay available to the relay afterwards.
            DataInputStream clientDataInput = new DataInputStream(clientInput);

            checkVersionNumber(clientDataInput);
            byte commandCode = readCommandCode(clientDataInput);
            int port = readPortNumber(clientDataInput);
            InetAddress ip = readIPAddress(clientDataInput);
            readUserIdString(clientDataInput);

            if (!checkConnection(ip, port)) {
                sendRequestRejectedMessage(clientOutput);
                return;
            }

            // NOTE(review): the "granted" reply is sent before the command code
            // is validated, so an unsupported command is granted first and
            // fails afterwards. Kept as is to preserve existing behaviour.
            sendRequestGrantedMessage(clientOutput);
            if (commandCode == Constants.ESTABLISH_CONNECTION) {
                establishConnection(clientInput, clientOutput, ip, port);
            } else if (commandCode == Constants.ESTABLISH_PORT_BINDING) {
                establishPortBinding(clientDataInput, clientOutput, ip, port);
            } else {
                throw new Socks4Exception("Unexpected command code: " + commandCode);
            }
        } catch (Socks4Exception e) {
            logger.log(Level.FINE, "Got Socks4Exception", e);
        }
    }

    /**
     * Read and check version number. It must be 0x04
     *
     * @param clientInput stream positioned at the first byte of the request
     * @throws IOException if the byte cannot be read
     * @throws Socks4Exception if the version differs from {@code Constants.VERSION_NUMBER}
     */
    private void checkVersionNumber(DataInputStream clientInput) throws IOException, Socks4Exception {
        byte versionNumber = clientInput.readByte();
        logger.log(Level.FINE, "Version number: {0}", versionNumber);
        if (versionNumber != Constants.VERSION_NUMBER) {
            throw new Socks4Exception(
                    "Version number is " + versionNumber + ", but it must be " + Constants.VERSION_NUMBER);
        }
    }

    /**
     * Reads the one-byte command code.
     *
     * @param clientInput request stream
     * @return the command code as sent by the client
     * @throws IOException if the byte cannot be read
     */
    private byte readCommandCode(DataInputStream clientInput) throws IOException {
        byte commandCode = clientInput.readByte();
        logger.log(Level.FINE, "Command code: {0}", commandCode);
        return commandCode;
    }

    /**
     * Reads the two-byte unsigned destination port.
     *
     * @param clientInput request stream
     * @return the port, in the range 0..65535
     * @throws IOException if the bytes cannot be read
     */
    private int readPortNumber(DataInputStream clientInput) throws IOException {
        int port = clientInput.readUnsignedShort();
        // Pass the port as text: a numeric argument would be rendered with
        // locale-specific grouping separators (e.g. "8,080").
        logger.log(Level.FINE, "Port: {0}", Integer.toString(port));
        return port;
    }

    /**
     * Reads the four-byte destination address.
     *
     * @param clientInput request stream
     * @return the address built from the four raw bytes
     * @throws IOException if the bytes cannot be read
     */
    private InetAddress readIPAddress(DataInputStream clientInput) throws IOException, Socks4Exception {
        byte[] addr = new byte[4];
        clientInput.readFully(addr, 0, 4);
        InetAddress ip = InetAddress.getByAddress(addr);
        logger.log(Level.FINE, "IP address: {0}", ip.toString());
        return ip;
    }

    /**
     * Reads the user id, which is terminated by a zero byte.
     *
     * @param clientInput request stream
     * @return the user id without the terminator; each byte is mapped to the
     *         character with the same unsigned value
     * @throws IOException if the terminator is not found before the stream ends
     */
    private String readUserIdString(DataInputStream clientInput) throws IOException, Socks4Exception {
        StringBuilder sb = new StringBuilder();
        byte b;
        while ((b = clientInput.readByte()) != 0x00) {
            // Append the character, not the numeric value: append(byte) would
            // resolve to append(int) and produce decimal digits.
            sb.append((char) (b & 0xFF));
        }

        String userId = sb.toString();
        logger.log(Level.FINE, "User ID: {0}", userId);

        return userId;
    }

    /**
     * Serves an "establish connection" request by relaying data between the
     * client and {@code ip:port}.
     *
     * <p>The remote socket is always closed when this method returns. I/O
     * failures are logged at {@code SEVERE} level and not propagated.
     *
     * @param clientInput  stream to read client data from
     * @param clientOutput stream to write remote data to
     * @param ip           remote host address
     * @param port         remote host port
     */
    private void establishConnection(InputStream clientInput, OutputStream clientOutput, InetAddress ip, int port) {
        Utils.sleep(delay / 2);

        // Port is logged as text to avoid locale-specific digit grouping.
        Object[] target = new Object[] { ip.toString(), Integer.toString(port) };
        logger.log(Level.FINE, "Establish connection with {0}:{1}", target);

        // try-with-resources guarantees the remote socket is released even
        // when connecting or relaying fails.
        try (Socket remoteSocket = new Socket(Proxy.NO_PROXY)) {
            logger.fine("Redirecting has been started");

            relay(clientInput, clientOutput, remoteSocket, ip, port);

            Utils.sleep(delay / 2);

            logger.fine("Close connection for remote host");
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Got I/O exception", e);
        }

        logger.log(Level.FINE, "Connection with {0}:{1} has been finished", target);
    }

    /**
     * Alternately forwards client data to the remote host and remote data to
     * the client until {@code checkData} declines a chunk or the current
     * thread is interrupted.
     *
     * <p>The remote socket is connected lazily, on the first accepted chunk of
     * client data. The method blocks (polling) until the client has sent data
     * once; after that, both directions are polled without waiting for data.
     *
     * @throws IOException if reading, connecting or writing fails
     */
    private void relay(InputStream clientInput, OutputStream clientOutput, Socket remoteSocket,
            InetAddress ip, int port) throws IOException {
        InputStream remoteInput = null;
        OutputStream remoteOutput = null;
        boolean hasClientData = false;

        while (true) {
            byte[] data;

            // wait until the client has sent some data (first round only)
            do {
                data = readDataFromInputStream(clientInput);
                if (data == null || data.length == 0) {
                    if (!pauseBeforeNextPoll()) {
                        return;
                    }
                } else {
                    hasClientData = true;
                }
            } while (!hasClientData);

            if (data != null && data.length != 0) {
                logger.log(Level.FINE, "Got data from client: {0}", new String(data));
            }

            if (!checkData(data)) {
                logger.log(Level.FINE, "Client data declined");
                return;
            }

            if (!remoteSocket.isConnected()) {
                remoteSocket.connect(new InetSocketAddress(ip, port));
            }
            if (remoteOutput == null) {
                remoteOutput = new BufferedOutputStream(remoteSocket.getOutputStream());
            }
            remoteOutput.write(data);
            remoteOutput.flush();

            if (remoteInput == null) {
                remoteInput = new BufferedInputStream(remoteSocket.getInputStream());
            }

            data = readDataFromInputStream(remoteInput);
            if (data == null || data.length == 0) {
                // Nothing from the remote host yet: pause so the next round
                // does not spin, then hand the (empty) chunk on as before.
                if (!pauseBeforeNextPoll()) {
                    return;
                }
            }

            if (data != null && data.length != 0) {
                logger.log(Level.FINE, "Got data from remote host: {0}", new String(data));
            }

            if (!checkData(data)) {
                logger.info("Remote host data declined");
                return;
            }
            clientOutput.write(data);
            clientOutput.flush();
        }
    }

    /**
     * Sleeps for one poll interval.
     *
     * @return {@code true} if the pause completed; {@code false} if the thread
     *         was interrupted, in which case the interrupt status is restored
     *         so that callers further up can observe it
     */
    private boolean pauseBeforeNextPoll() {
        try {
            Thread.sleep(POLL_INTERVAL_MILLIS);
            return true;
        } catch (InterruptedException e) {
            logger.log(Level.WARNING, "Got InterruptedException", e);
            Thread.currentThread().interrupt();
            return false;
        }
    }

    /**
     * Port binding is not implemented; this method always throws.
     *
     * @throws UnsupportedOperationException always
     */
    private void establishPortBinding(InputStream clientInput, OutputStream clientOutput, InetAddress ip, int port) {
        logger.log(Level.FINE, "Establish port binding on {0}:{1}",
                new Object[] { ip.toString(), Integer.toString(port) });
        throw new UnsupportedOperationException("Establishing port binding is not supported");
    }

    /**
     * Sends the 8-byte reply whose second byte is
     * {@code Constants.REQUEST_GRANTED}; all other bytes are zero.
     *
     * @param clientOutput stream to the client; flushed after writing
     * @throws IOException if the reply cannot be written
     */
    private void sendRequestGrantedMessage(OutputStream clientOutput) throws IOException {
        byte[] response = new byte[8];
        response[1] = Constants.REQUEST_GRANTED;
        clientOutput.write(response);
        clientOutput.flush();
    }

    /**
     * Sends the 8-byte reply whose second byte is
     * {@code Constants.REQUEST_REJECTED}; all other bytes are zero.
     *
     * @param clientOutput stream to the client; flushed after writing
     * @throws IOException if the reply cannot be written
     */
    private void sendRequestRejectedMessage(OutputStream clientOutput) throws IOException {
        byte[] response = new byte[8];
        response[1] = Constants.REQUEST_REJECTED;
        clientOutput.write(response);
        clientOutput.flush();
    }

    /**
     * Drains whatever the stream reports as currently available, without
     * waiting for more data to arrive.
     *
     * @param is stream to read from
     * @return the bytes read; an empty array when nothing was available
     * @throws IOException if querying or reading the stream fails
     */
    protected byte[] readDataFromInputStream(InputStream is) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buf = new byte[BUFFER_SIZE];
        // Only read while available() reports data, so the call returns
        // instead of waiting on an idle peer.
        while (is.available() > 0) {
            int read = is.read(buf, 0, BUFFER_SIZE);
            if (read <= 0) {
                break;
            }
            bos.write(buf, 0, read);
        }
        return bos.toByteArray();
    }
}
