package code.ponfee.commons.jce.sm;

import code.ponfee.commons.jce.ECParameters;
import code.ponfee.commons.util.Bytes;
import code.ponfee.commons.util.SecureRandoms;
import java.io.ByteArrayOutputStream;
import java.io.Serializable;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Objects;
import org.bouncycastle.math.ec.ECPoint;

/**
 * SM2 key exchange between an initiator A and a responder B.
 *
 * <p>Protocol flow:
 * <ol>
 *   <li>A calls {@link #step1PartA()} and sends the result to B.</li>
 *   <li>B calls {@link #step2PartB(TransportEntity)}. B's shared key is then available via
 *       {@link #getKey()}, and the result is sent to A.</li>
 *   <li>A calls {@link #step3PartA(TransportEntity)}, which checks B's confirmation value. It
 *       returns {@code null} on mismatch. Otherwise A's key is available and the result is sent
 *       to B.</li>
 *   <li>B calls {@link #step4PartB(TransportEntity)} to check A's confirmation value.</li>
 * </ol>
 *
 * <p>Instances hold mutable per-session state without any synchronization. Use one instance per
 * exchange and do not share it across threads.
 *
 * <p>NOTE(review): point coordinates are hashed via {@link BigInteger#toByteArray()}. That is not
 * the fixed-length field-element encoding of GB/T 32918: it may prepend a 0x00 sign byte or drop
 * leading zero bytes. Two peers that both use this class still agree with each other, but
 * interoperability with other SM2 implementations may fail for some keys. The encoding is kept
 * unchanged so that existing peers keep deriving the same key.
 *
 * <p>NOTE(review): BouncyCastle's {@link ECPoint} does not appear to implement
 * {@code Serializable}. Java-serializing an instance (which always holds {@code publicKey}) would
 * then fail. Confirm whether {@code Serializable} is actually needed here.
 */
public class SM2KeyExchanger implements Serializable {
    private static final long serialVersionUID = 8553046425593791291L;
    private static final BigInteger TWO = BigInteger.valueOf(2);

    /** Length in bytes of the derived shared key (128 bits). */
    private static final int KEY_LENGTH = 16;

    /** This side's ephemeral private key (r_A on the initiator, r_B on the responder). */
    private BigInteger rA;
    /** This side's ephemeral public point, [rA]G. */
    private ECPoint RA;
    /** Shared point computed in step 2 (responder) or step 3 (initiator). */
    private ECPoint V;
    /** Derived shared key; {@code null} until a step has produced it. */
    private byte[] key;

    private final ECParameters ecParam;
    /** The value 2^w, where w = ceil(bitLength(n) / 2) - 1. */
    private final BigInteger w;
    /** This party's static public key. */
    private final ECPoint publicKey;
    /** This party's static private key. */
    private final BigInteger privateKey;
    /** This party's Z value, as computed by {@code SM2.calcZ}. */
    private final byte[] Z;

    /** Message sent from one party to the other at each protocol step. */
    public static class TransportEntity implements Serializable {
        private static final long serialVersionUID = 3657694935421411649L;
        /** Sender's ephemeral public point, uncompressed encoding. */
        private final byte[] R;
        /** Sender's confirmation hash; {@code null} in the first message. */
        private final byte[] S;
        /** Sender's Z value. */
        private final byte[] Z;
        /** Sender's static public key, uncompressed encoding. */
        private final byte[] K;

        TransportEntity(byte[] r, byte[] s, byte[] z, ECPoint k) {
            this(r, s, z, k.getEncoded(false));
        }

        TransportEntity(byte[] r, byte[] s, byte[] z, byte[] k) {
            this.R = r;
            this.S = s;
            this.Z = z;
            this.K = k;
        }

        public byte[] getR() {
            return this.R;
        }

        public byte[] getS() {
            return this.S;
        }

        public byte[] getZ() {
            return this.Z;
        }

        public byte[] getK() {
            return this.K;
        }
    }

    /** Creates an exchanger with no identity bytes, using {@code ECParameters.SM2_BEST}. */
    public SM2KeyExchanger(ECPoint publicKey, BigInteger privateKey) {
        this(null, publicKey, privateKey, ECParameters.SM2_BEST);
    }

    /** Creates an exchanger using {@code ECParameters.SM2_BEST}. */
    public SM2KeyExchanger(byte[] id, ECPoint publicKey, BigInteger privateKey) {
        this(id, publicKey, privateKey, ECParameters.SM2_BEST);
    }

    /** Creates an exchanger with no identity bytes. */
    public SM2KeyExchanger(ECPoint publicKey, BigInteger privateKey, ECParameters ecParameters) {
        this(null, publicKey, privateKey, ecParameters);
    }

    /**
     * Creates an exchanger.
     *
     * @param id           identity bytes passed to {@code SM2.calcZ}. May be {@code null}, as the
     *                     shorter constructors pass; presumably calcZ then applies a default ID
     *                     (TODO confirm).
     * @param publicKey    this party's static public key
     * @param privateKey   this party's static private key
     * @param ecParameters curve parameters
     * @throws NullPointerException if {@code publicKey}, {@code privateKey} or
     *                              {@code ecParameters} is null
     */
    public SM2KeyExchanger(byte[] id, ECPoint publicKey, BigInteger privateKey, ECParameters ecParameters) {
        this.ecParam = Objects.requireNonNull(ecParameters, "ecParameters");
        this.publicKey = Objects.requireNonNull(publicKey, "publicKey");
        this.privateKey = Objects.requireNonNull(privateKey, "privateKey");
        // w = ceil(bitLength(n) / 2) - 1; (bits + 1) / 2 is the integer form of ceil(bits / 2).
        int bits = ecParameters.n.bitLength();
        this.w = TWO.pow((bits + 1) / 2 - 1);
        this.Z = SM2.calcZ(SM3Digest.getInstance(), ecParameters, id, publicKey);
    }

    /**
     * Step 1 (initiator): generates the ephemeral key pair (rA, RA).
     *
     * @return the message for the responder, carrying RA, this party's Z and its public key
     */
    public TransportEntity step1PartA() {
        this.rA = SecureRandoms.random(this.ecParam.n);
        this.RA = this.ecParam.pointG.multiply(this.rA).normalize();
        return new TransportEntity(this.RA.getEncoded(false), null, this.Z, this.publicKey);
    }

    /**
     * Step 2 (responder): generates (rB, RB) and computes the shared point V and key KB.
     * Also computes the confirmation value SB.
     *
     * @param transportEntity the initiator's step-1 message
     * @return the message for the initiator, carrying RB and SB
     * @throws IllegalStateException if V is the point at infinity
     */
    public TransportEntity step2PartB(TransportEntity transportEntity) {
        // The responder's ephemeral pair lives in the same fields as the initiator's.
        BigInteger rB = SecureRandoms.random(this.ecParam.n);
        ECPoint pointRB = this.ecParam.pointG.multiply(rB).normalize();
        this.rA = rB;
        this.RA = pointRB;

        BigInteger tB = calcT(rB, pointRB);
        ECPoint pointRA = this.ecParam.curve.decodePoint(transportEntity.R).normalize();
        ECPoint v = calcV(transportEntity.K, pointRA, tB);
        this.V = v;

        byte[] xV = v.getXCoord().toBigInteger().toByteArray();
        byte[] yV = v.getYCoord().toBigInteger().toByteArray();
        // KB = KDF(xV || yV || ZA || ZB)
        this.key = kdf(Bytes.concat(xV, yV, transportEntity.Z, this.Z), KEY_LENGTH);

        byte[] sB = confirmation((byte) 2, xV, yV, transportEntity.Z, this.Z, pointRA, pointRB);
        return new TransportEntity(pointRB.getEncoded(false), sB, this.Z, this.publicKey);
    }

    /**
     * Step 3 (initiator): computes the shared point U and checks the responder's SB.
     * On success, stores the key KA and returns the confirmation value SA.
     *
     * <p>The shared point and key are stored only if SB matches, so {@link #getKey()} never
     * exposes an unconfirmed key on this side.
     *
     * @param transportEntity the responder's step-2 message
     * @return the message for the responder, or {@code null} if SB does not match
     * @throws IllegalStateException if {@link #step1PartA()} has not been called, or U is the
     *                               point at infinity
     */
    public TransportEntity step3PartA(TransportEntity transportEntity) {
        checkState(this.rA != null && this.RA != null, "step1PartA must be called before step3PartA");

        BigInteger tA = calcT(this.rA, this.RA);
        ECPoint pointRB = this.ecParam.curve.decodePoint(transportEntity.R).normalize();
        ECPoint u = calcV(transportEntity.K, pointRB, tA);

        byte[] xU = u.getXCoord().toBigInteger().toByteArray();
        byte[] yU = u.getYCoord().toBigInteger().toByteArray();

        byte[] s1 = confirmation((byte) 2, xU, yU, this.Z, transportEntity.Z, this.RA, pointRB);
        // Constant-time comparison: S values act as MACs.
        if (!MessageDigest.isEqual(transportEntity.S, s1)) {
            return null;
        }

        this.V = u;
        // KA = KDF(xU || yU || ZA || ZB)
        this.key = kdf(Bytes.concat(xU, yU, this.Z, transportEntity.Z), KEY_LENGTH);

        byte[] sA = confirmation((byte) 3, xU, yU, this.Z, transportEntity.Z, this.RA, pointRB);
        return new TransportEntity(this.RA.getEncoded(false), sA, this.Z, this.publicKey);
    }

    /**
     * Step 4 (responder): checks the initiator's confirmation value SA.
     *
     * @param transportEntity the initiator's step-3 message
     * @return {@code true} if SA matches the value recomputed from V
     * @throws IllegalStateException if {@link #step2PartB(TransportEntity)} has not been called
     */
    public boolean step4PartB(TransportEntity transportEntity) {
        checkState(this.V != null && this.RA != null, "step2PartB must be called before step4PartB");

        byte[] xV = this.V.getXCoord().toBigInteger().toByteArray();
        byte[] yV = this.V.getYCoord().toBigInteger().toByteArray();
        ECPoint pointRA = this.ecParam.curve.decodePoint(transportEntity.R).normalize();

        byte[] s2 = confirmation((byte) 3, xV, yV, transportEntity.Z, this.Z, pointRA, this.RA);
        return MessageDigest.isEqual(transportEntity.S, s2);
    }

    /**
     * Returns a copy of the derived shared key.
     *
     * @return the key bytes, or {@code null} if no key has been derived yet
     */
    public byte[] getKey() {
        return this.key == null ? null : this.key.clone();
    }

    /** Computes x̄ = 2^w + (x & (2^w - 1)) from the x-coordinate of {@code point}. */
    private BigInteger xBar(ECPoint point) {
        return this.w.add(point.getXCoord().toBigInteger().and(this.w.subtract(BigInteger.ONE)));
    }

    /** Computes t = (d + x̄ · r) mod n from this party's private key and ephemeral pair. */
    private BigInteger calcT(BigInteger ephemeralPrivate, BigInteger unused, ECPoint ignored) {
        throw new UnsupportedOperationException();
    }

    /** Computes t = (d + x̄(R) · r) mod n from this party's private key and ephemeral pair (r, R). */
    private BigInteger calcT(BigInteger ephemeralPrivate, ECPoint ephemeralPoint) {
        return this.privateKey.add(xBar(ephemeralPoint).multiply(ephemeralPrivate)).mod(this.ecParam.n);
    }

    /**
     * Computes the shared point [h · t](P + [x̄(R)]R).
     *
     * @param peerPublicKey encoded static public key P of the peer
     * @param peerR         the peer's ephemeral point R
     * @param t             this party's t value
     * @throws IllegalStateException if the result is the point at infinity
     */
    private ECPoint calcV(byte[] peerPublicKey, ECPoint peerR, BigInteger t) {
        ECPoint peer = this.ecParam.curve.decodePoint(peerPublicKey).normalize();
        ECPoint shared = peer.add(peerR.multiply(xBar(peerR)).normalize())
                             .normalize()
                             .multiply(this.ecParam.bcSpec.getH().multiply(t))
                             .normalize();
        if (shared.isInfinity()) {
            throw new IllegalStateException("SM2 key exchange failed: shared point is at infinity");
        }
        return shared;
    }

    /**
     * Computes a confirmation value Hash(prefix || y || Hash(x || ZA || ZB || x1 || y1 || x2 || y2)).
     * The prefix is 0x02 for SB/S1 and 0x03 for SA/S2.
     */
    private static byte[] confirmation(byte prefix, byte[] x, byte[] y, byte[] zA, byte[] zB,
                                       ECPoint pointA, ECPoint pointB) {
        SM3Digest sm3 = SM3Digest.getInstance();
        byte[] inner = digest(sm3, x, zA, zB, pointA, pointB);
        sm3.update(prefix);
        sm3.update(y);
        sm3.update(inner);
        return sm3.doFinal();
    }

    /**
     * Resets {@code sm3} and returns Hash(x || zA || zB || x1 || y1 || x2 || y2).
     * Coordinates come from {@code pointA} and {@code pointB}, encoded via
     * {@link BigInteger#toByteArray()}.
     */
    private static byte[] digest(SM3Digest sm3, byte[] x, byte[] zA, byte[] zB, ECPoint pointA, ECPoint pointB) {
        sm3.reset();
        sm3.update(x);
        sm3.update(zA);
        sm3.update(zB);
        sm3.update(pointA.getXCoord().toBigInteger().toByteArray());
        sm3.update(pointA.getYCoord().toBigInteger().toByteArray());
        sm3.update(pointB.getXCoord().toBigInteger().toByteArray());
        sm3.update(pointB.getYCoord().toBigInteger().toByteArray());
        return sm3.doFinal();
    }

    /**
     * Counter-mode SM3 key derivation: Hash(z || ct) for ct = 1, 2, …, truncated to
     * {@code length} bytes.
     *
     * @param z      shared secret input
     * @param length output length in <em>bytes</em>; a non-positive value yields an empty array
     * @return the derived key material
     */
    private static byte[] kdf(byte[] z, int length) {
        if (length <= 0) {
            return new byte[0];
        }
        // Number of 32-byte SM3 output blocks needed: ceil(length / 32).
        int blocks = length / 32 + ((length & 31) == 0 ? 0 : 1);
        SM3Digest sm3 = SM3Digest.getInstance();
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int ct = 1;
        for (; ct < blocks; ct++) {
            sm3.update(z);
            byte[] block = sm3.doFinal(Bytes.toBytes(ct));
            out.write(block, 0, block.length);
        }
        // The last block is truncated to the remaining byte count (a full block when length % 32 == 0).
        sm3.update(z);
        sm3.update(Bytes.toBytes(ct));
        byte[] last = sm3.doFinal();
        int tail = length & 31;
        out.write(last, 0, tail == 0 ? last.length : tail);
        return out.toByteArray();
    }

    private static void checkState(boolean condition, String message) {
        if (!condition) {
            throw new IllegalStateException(message);
        }
    }
}
