/*
 * Decompiled with CFR 0.152.
 */
package io.undertow.protocols.ssl;

import io.undertow.UndertowMessages;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLException;

/**
 * Rewrites a TLS ServerHello handshake message so that its ALPN extension
 * (extension type 16, RFC 7301) is either added or removed.
 *
 * <p>All methods work on raw bytes. When a rewrite is not applicable (an ALPN
 * extension is already present when one should be added, or the extensions block
 * has no ALPN entry when one should be removed) an internal
 * {@link AlpnProcessingException} is thrown, and the package-level entry points
 * translate it into their documented fallback return value.
 */
final class ALPNHackServerHelloExplorer {

    /** Handshake message type of a ServerHello. */
    private static final int SERVER_HELLO = 2;

    /** TLS extension type number of the ALPN extension. */
    private static final int ALPN_EXTENSION_TYPE = 16;

    /** Size of a handshake message header: 1-byte type followed by a 24-bit length. */
    private static final int HANDSHAKE_HEADER_LENGTH = 4;

    /** Size of a TLS record header: content type, 2-byte version, 16-bit length. */
    private static final int RECORD_HEADER_LENGTH = 5;

    /** Length of the fixed-size random field in a ServerHello. */
    private static final int RANDOM_LENGTH = 32;

    private ALPNHackServerHelloExplorer() {
    }

    /**
     * Returns a copy of {@code source} whose leading ServerHello carries an ALPN
     * extension selecting {@code selectedAlpnProtocol}.
     *
     * @param source               handshake bytes (no record header) beginning with a
     *                             ServerHello; any bytes after the ServerHello are copied verbatim
     * @param selectedAlpnProtocol protocol name to advertise; each {@code char} is emitted as
     *                             its low-order byte
     * @return the rewritten bytes, or {@code source} itself if the ServerHello already
     *         contains an ALPN extension
     * @throws SSLException if the first message is not a ServerHello, or its declared length
     *                      is larger than {@code source} allows
     */
    static byte[] addAlpnExtensionsToServerHello(byte[] source, String selectedAlpnProtocol) throws SSLException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ByteBuffer input = ByteBuffer.wrap(source);
        try {
            exploreHandshake(input, source.length, new AtomicReference<>(selectedAlpnProtocol), out);
        } catch (AlpnProcessingException e) {
            return source;
        }
        // Measure the rewritten ServerHello body before the trailing bytes are appended.
        int serverHelloLength = out.size() - HANDSHAKE_HEADER_LENGTH;
        out.write(source, input.position(), input.remaining());
        byte[] data = out.toByteArray();
        patchHandshakeLength(data, serverHelloLength);
        return data;
    }

    /**
     * Strips the ALPN extension from the ServerHello at the current position of
     * {@code source} and reports the protocol it selected.
     *
     * <p>The buffer's position is advanced as the message is parsed. Its limit is
     * restored before this method returns, including when parsing is aborted.
     *
     * @param source               buffer positioned at the start of a ServerHello handshake
     *                             message (no record header)
     * @param selectedAlpnProtocol receives the protocol found in the ALPN extension. When it
     *                             holds {@code null} on entry the extension is removed; a
     *                             non-null value instead makes the rewrite add an ALPN
     *                             extension for that protocol
     * @return the rewritten ServerHello bytes, or {@code null} if the ServerHello has an
     *         extensions block containing no ALPN extension (a ServerHello with no extensions
     *         block at all is returned re-serialized unchanged)
     * @throws SSLException if the message is not a ServerHello, or its declared length is
     *                      larger than the remaining bytes of {@code source}
     */
    static byte[] removeAlpnExtensionsFromServerHello(ByteBuffer source, AtomicReference<String> selectedAlpnProtocol) throws SSLException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try {
            exploreHandshake(source, source.remaining(), selectedAlpnProtocol, out);
        } catch (AlpnProcessingException e) {
            return null;
        }
        byte[] data = out.toByteArray();
        patchHandshakeLength(data, data.length - HANDSHAKE_HEADER_LENGTH);
        return data;
    }

    /**
     * Copies the handshake header of a ServerHello to {@code out}, then rewrites its body.
     * The 24-bit length is written as a zero placeholder, which callers patch afterwards.
     */
    private static void exploreHandshake(ByteBuffer input, int recordLength, AtomicReference<String> selectedAlpnProtocol, ByteArrayOutputStream out) throws SSLException {
        byte handshakeType = input.get();
        if (handshakeType != SERVER_HELLO) {
            throw UndertowMessages.MESSAGES.expectedServerHello();
        }
        out.write(handshakeType);
        int handshakeLength = getInt24(input);
        // Placeholder for the 24-bit body length; the real value is only known after rewriting.
        out.write(0);
        out.write(0);
        out.write(0);
        if (handshakeLength > recordLength - HANDSHAKE_HEADER_LENGTH) {
            throw UndertowMessages.MESSAGES.multiRecordSSLHandshake();
        }
        int oldLimit = input.limit();
        // Confine parsing to this message so exploreServerHello can detect extensions via hasRemaining().
        input.limit(input.position() + handshakeLength);
        try {
            exploreServerHello(input, selectedAlpnProtocol, out);
        } finally {
            // Restore even when parsing aborts: the buffer may belong to the caller.
            input.limit(oldLimit);
        }
    }

    /**
     * Copies a ServerHello body to {@code out}, adding or removing the ALPN extension.
     *
     * <p>If {@code alpnProtocolReference} holds {@code null}, any ALPN extension is stripped
     * and its protocol is stored into the reference. Otherwise an ALPN extension selecting
     * the referenced protocol is appended. If the message already has one in that case,
     * {@link AlpnProcessingException} is thrown.
     */
    private static void exploreServerHello(ByteBuffer input, AtomicReference<String> alpnProtocolReference, ByteArrayOutputStream out) throws SSLException {
        // server_version: major, minor
        processByteVector(input, 2, out);
        // random
        processByteVector(input, RANDOM_LENGTH, out);
        // session_id (8-bit length prefix)
        processByteVector8(input, out);
        // cipher_suite
        processByteVector(input, 2, out);
        // compression_method
        processByteVector(input, 1, out);

        String selectedAlpnProtocol = alpnProtocolReference.get();
        String existingAlpn = null;
        ByteArrayOutputStream extensionsOutput = null;
        if (input.hasRemaining()) {
            extensionsOutput = new ByteArrayOutputStream();
            existingAlpn = exploreExtensions(input, extensionsOutput, selectedAlpnProtocol == null);
        }

        if (existingAlpn != null) {
            if (selectedAlpnProtocol != null) {
                // Asked to add ALPN, but the message already carries an ALPN extension.
                throw new AlpnProcessingException();
            }
            alpnProtocolReference.set(existingAlpn);
            writeAll(out, extensionsOutput.toByteArray());
        } else if (selectedAlpnProtocol != null) {
            byte[] alpnExtension = buildAlpnExtension(selectedAlpnProtocol);
            if (extensionsOutput != null) {
                byte[] existing = extensionsOutput.toByteArray();
                // Grow the existing 16-bit extensions length to cover the appended extension.
                int newLength = existing.length - 2 + alpnExtension.length;
                existing[0] = (byte) (newLength >> 8 & 0xFF);
                existing[1] = (byte) (newLength & 0xFF);
                writeAll(out, existing);
            } else {
                // No extensions block yet: start one containing only ALPN.
                writeInt16(out, alpnExtension.length);
            }
            writeAll(out, alpnExtension);
        } else if (extensionsOutput != null) {
            writeAll(out, extensionsOutput.toByteArray());
        }
    }

    /**
     * Serializes an ALPN extension carrying exactly one protocol name:
     * type (2 bytes), extension length (2), protocol list length (2), name length (1), name.
     * Each {@code char} of {@code protocol} is written as its low-order byte.
     */
    private static byte[] buildAlpnExtension(String protocol) {
        ByteArrayOutputStream extension = new ByteArrayOutputStream();
        writeInt16(extension, ALPN_EXTENSION_TYPE);
        int extensionDataLength = 3 + protocol.length();
        writeInt16(extension, extensionDataLength);
        // The protocol list excludes its own 2-byte length prefix.
        writeInt16(extension, extensionDataLength - 2);
        extension.write(protocol.length() & 0xFF);
        for (int i = 0; i < protocol.length(); ++i) {
            extension.write(protocol.charAt(i) & 0xFF);
        }
        return extension.toByteArray();
    }

    /**
     * Splits {@code data} into individual TLS records. Each record is read using the 16-bit
     * length in its 5-byte header and returned as its own buffer, including that header.
     * {@code data} is consumed; a truncated record raises {@code BufferUnderflowException}.
     */
    static List<ByteBuffer> extractRecords(ByteBuffer data) {
        List<ByteBuffer> records = new ArrayList<>();
        while (data.hasRemaining()) {
            byte contentType = data.get();
            byte majorVersion = data.get();
            byte minorVersion = data.get();
            byte lengthHigh = data.get();
            byte lengthLow = data.get();
            int length = (lengthHigh & 0xFF) << 8 | lengthLow & 0xFF;
            byte[] record = new byte[RECORD_HEADER_LENGTH + length];
            record[0] = contentType;
            record[1] = majorVersion;
            record[2] = minorVersion;
            record[3] = lengthHigh;
            record[4] = lengthLow;
            data.get(record, RECORD_HEADER_LENGTH, length);
            records.add(ByteBuffer.wrap(record));
        }
        return records;
    }

    /**
     * Copies a 16-bit-length-prefixed extensions block to {@code extensionOut}, optionally
     * dropping the ALPN extension.
     *
     * @param removeAlpn when {@code true}, the ALPN extension is omitted from the output and
     *                   its absence raises {@link AlpnProcessingException}
     * @return the protocol name found in the ALPN extension, or {@code null} if there was none
     */
    private static String exploreExtensions(ByteBuffer input, ByteArrayOutputStream extensionOut, boolean removeAlpn) throws SSLException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String alpnProtocol = null;
        int remaining = getInt16(input);
        // Placeholder for the extensions length; overwritten once any ALPN removal is accounted for.
        writeInt16(out, 0);
        int newLength = remaining;
        while (remaining > 0) {
            int extType = getInt16(input);
            int extLen = getInt16(input);
            if (extType == ALPN_EXTENSION_TYPE) {
                int vlen = getInt16(input);
                alpnProtocol = readByteVector8(input);
                if (removeAlpn) {
                    // Dropped bytes: type (2) + extension length (2) + list length (2) + list body.
                    newLength -= 6 + vlen;
                } else {
                    writeInt16(out, extType);
                    writeInt16(out, extLen);
                    writeInt16(out, vlen);
                    out.write(alpnProtocol.length() & 0xFF);
                    for (int i = 0; i < alpnProtocol.length(); ++i) {
                        out.write(alpnProtocol.charAt(i) & 0xFF);
                    }
                }
            } else {
                writeInt16(out, extType);
                writeInt16(out, extLen);
                processByteVector(input, extLen, out);
            }
            remaining -= extLen + 4;
        }
        if (removeAlpn && alpnProtocol == null) {
            // Nothing to remove, so this rewrite does not apply.
            throw new AlpnProcessingException();
        }
        byte[] data = out.toByteArray();
        data[0] = (byte) (newLength >> 8 & 0xFF);
        data[1] = (byte) (newLength & 0xFF);
        writeAll(extensionOut, data);
        return alpnProtocol;
    }

    /**
     * Reads an 8-bit-length-prefixed byte string and decodes it as US-ASCII.
     * Non-ASCII bytes decode to the replacement character.
     */
    private static String readByteVector8(ByteBuffer input) {
        int length = getInt8(input);
        byte[] data = new byte[length];
        input.get(data);
        return new String(data, StandardCharsets.US_ASCII);
    }

    /** Reads an unsigned 8-bit integer (0-255). */
    private static int getInt8(ByteBuffer input) {
        // Mask so lengths above 127 are not sign-extended into negative values.
        return input.get() & 0xFF;
    }

    /** Reads an unsigned big-endian 16-bit integer. */
    private static int getInt16(ByteBuffer input) {
        return (input.get() & 0xFF) << 8 | input.get() & 0xFF;
    }

    /** Reads an unsigned big-endian 24-bit integer. */
    private static int getInt24(ByteBuffer input) {
        return (input.get() & 0xFF) << 16 | (input.get() & 0xFF) << 8 | input.get() & 0xFF;
    }

    /** Writes the low 16 bits of {@code value} in big-endian order. */
    private static void writeInt16(ByteArrayOutputStream out, int value) {
        out.write(value >> 8 & 0xFF);
        out.write(value & 0xFF);
    }

    /** Appends all of {@code bytes}; unlike {@code write(byte[])} this declares no checked exception. */
    private static void writeAll(ByteArrayOutputStream out, byte[] bytes) {
        out.write(bytes, 0, bytes.length);
    }

    /** Stores {@code length} as the big-endian 24-bit length in bytes 1-3 of a handshake message. */
    private static void patchHandshakeLength(byte[] handshake, int length) {
        handshake[1] = (byte) (length >> 16 & 0xFF);
        handshake[2] = (byte) (length >> 8 & 0xFF);
        handshake[3] = (byte) (length & 0xFF);
    }

    /** Copies an 8-bit-length-prefixed byte string, including its length byte, to {@code out}. */
    private static void processByteVector8(ByteBuffer input, ByteArrayOutputStream out) {
        int length = getInt8(input);
        out.write(length);
        processByteVector(input, length, out);
    }

    /** Copies the next {@code length} bytes of {@code input} to {@code out}. */
    private static void processByteVector(ByteBuffer input, int length, ByteArrayOutputStream out) {
        for (int i = 0; i < length; ++i) {
            out.write(input.get() & 0xFF);
        }
    }

    /**
     * Builds a buffer holding a replacement first record followed by the remaining records.
     *
     * <p>The new first record reuses the first three header bytes (content type and protocol
     * version) of {@code records.get(0)} and gets a fresh 16-bit length for
     * {@code newFirstMessage}. Records 1..n are appended as-is.
     *
     * <p>Side effects: three bytes of {@code records.get(0)} are consumed, and every later
     * record buffer is fully consumed.
     *
     * @return a buffer flipped and ready for reading
     */
    static ByteBuffer createNewOutputRecords(byte[] newFirstMessage, List<ByteBuffer> records) {
        int length = RECORD_HEADER_LENGTH + newFirstMessage.length;
        for (int i = 1; i < records.size(); ++i) {
            length += records.get(i).remaining();
        }
        ByteBuffer ret = ByteBuffer.wrap(new byte[length]);
        ByteBuffer oldHello = records.get(0);
        ret.put(oldHello.get());
        ret.put(oldHello.get());
        ret.put(oldHello.get());
        ret.put((byte) (newFirstMessage.length >> 8 & 0xFF));
        ret.put((byte) (newFirstMessage.length & 0xFF));
        ret.put(newFirstMessage);
        for (int i = 1; i < records.size(); ++i) {
            ret.put(records.get(i));
        }
        ret.flip();
        return ret;
    }

    /**
     * Internal signal that the requested ALPN rewrite does not apply to this ServerHello.
     * It never escapes the package-level entry points.
     */
    private static final class AlpnProcessingException
    extends RuntimeException {
        private AlpnProcessingException() {
        }
    }
}

