package dk.alexandra.fresco.suite.spdz2k.protocols.computations;

import dk.alexandra.fresco.commitment.HashBasedCommitment;
import dk.alexandra.fresco.framework.DRes;
import dk.alexandra.fresco.framework.MaliciousException;
import dk.alexandra.fresco.framework.builder.Computation;
import dk.alexandra.fresco.framework.builder.numeric.ProtocolBuilderNumeric;
import dk.alexandra.fresco.framework.network.serializers.ByteSerializer;
import dk.alexandra.fresco.framework.util.Drbg;
import dk.alexandra.fresco.framework.util.Pair;
import dk.alexandra.fresco.suite.spdz2k.datatypes.CompUInt;
import dk.alexandra.fresco.suite.spdz2k.datatypes.CompUIntConverter;
import dk.alexandra.fresco.suite.spdz2k.datatypes.CompUIntFactory;
import dk.alexandra.fresco.suite.spdz2k.datatypes.Spdz2kSInt;
import dk.alexandra.fresco.suite.spdz2k.datatypes.UInt;
import dk.alexandra.fresco.suite.spdz2k.resource.Spdz2kResourcePool;
import dk.alexandra.fresco.suite.spdz2k.resource.storage.Spdz2kDataSupplier;
import dk.alexandra.fresco.suite.spdz2k.resource.storage.Spdz2kOpenedValueStore;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:dk/alexandra/fresco/suite/spdz2k/protocols/computations/Spdz2kMacCheckComputation.class */
/**
 * Batched MAC check for the SPDZ2k protocol suite.
 *
 * <p>Checks all values pending in the resource pool's {@link Spdz2kOpenedValueStore} in a single
 * run. One random coefficient per pending value is sampled from the resource pool's random
 * generator when this computation is constructed. During evaluation each party:
 * <ol>
 *   <li>(only with more than two parties) broadcasts the least significant part of its share of
 *       every opened value;</li>
 *   <li>broadcasts a p-value derived from its shares, the coefficients and a fresh random share
 *       {@code r};</li>
 *   <li>computes a z-share from its key share, the MAC shares and all broadcast p-values, and runs
 *       it through {@link Spdz2kCommitmentComputation};</li>
 *   <li>deserializes and sums the values returned by that commitment computation.</li>
 * </ol>
 * If the sum is zero the opened-value store is cleared; otherwise a {@link MaliciousException} is
 * thrown.</p>
 *
 * <p>NOTE(review): correctness presumably relies on {@code getRandomGenerator()} producing the same
 * coefficients at every party. Confirm this against the resource pool.</p>
 *
 * @param <HighT> type of the most significant part of a {@code PlainT}
 * @param <LowT> type of the least significant part of a {@code PlainT}
 * @param <PlainT> composite type of plain values and shares
 */
public class Spdz2kMacCheckComputation<HighT extends UInt<HighT>, LowT extends UInt<LowT>,
        PlainT extends CompUInt<HighT, LowT, PlainT>>
        implements Computation<Void, ProtocolBuilderNumeric> {

    private final CompUIntConverter<HighT, LowT, PlainT> converter;
    private final Spdz2kOpenedValueStore<PlainT> openedValueStore;
    private final ByteSerializer<PlainT> serializer;
    private final Spdz2kDataSupplier<PlainT> supplier;
    /** One coefficient per value that was pending in the opened-value store at construction. */
    private final List<PlainT> randomCoefficients;
    private final ByteSerializer<HashBasedCommitment> commitmentSerializer;
    private final int noOfParties;
    private final Drbg localDrbg;

    /**
     * Creates a MAC check over the values currently pending in the resource pool's opened-value
     * store.
     *
     * <p>Coefficients are sampled eagerly here, one per value reported by
     * {@code getNumPending()}. NOTE(review): {@link #buildComputation} indexes these coefficients
     * by the position of each value it peeks from the store. It therefore expects the pending count
     * not to grow between construction and building. Confirm this at the call site.</p>
     *
     * @param resourcePool source of the opened-value store, serializers, data supplier, random
     *     generators and party count
     * @param converter builds {@code PlainT} values from their high or low parts
     */
    public Spdz2kMacCheckComputation(Spdz2kResourcePool<PlainT> resourcePool,
            CompUIntConverter<HighT, LowT, PlainT> converter) {
        this.openedValueStore = resourcePool.getOpenedValueStore();
        this.converter = converter;
        this.serializer = resourcePool.getPlainSerializer();
        this.supplier = resourcePool.getDataSupplier();
        this.randomCoefficients = sampleCoefficients(resourcePool.getRandomGenerator(),
                resourcePool.getFactory(), openedValueStore.getNumPending());
        this.commitmentSerializer = resourcePool.getCommitmentSerializer();
        this.noOfParties = resourcePool.getNoOfParties();
        this.localDrbg = resourcePool.getLocalRandomGenerator();
    }

    /**
     * Schedules the MAC check over all pending opened values.
     *
     * <p>The store is only peeked here. It is cleared in the final step, after the check has
     * passed.</p>
     *
     * @param builder builder on which the check's sub-computations are scheduled
     * @return a result that completes once the check has passed and the store has been cleared
     * @throws MaliciousException (during evaluation) if the summed z-values are not zero
     */
    @Override
    public DRes<Void> buildComputation(ProtocolBuilderNumeric builder) {
        Pair<List<Spdz2kSInt<PlainT>>, List<PlainT>> opened = openedValueStore.peekValues();
        List<Spdz2kSInt<PlainT>> authenticatedElements = opened.getFirst();
        List<PlainT> openValues = opened.getSecond();
        // Supplier calls are kept in their original order: they may consume pre-processed data.
        PlainT keyShare = supplier.getSecretSharedKey();
        PlainT y = UInt.innerProduct(openValues, randomCoefficients);
        Spdz2kSInt<PlainT> r = supplier.getNextRandomElementShare();
        return builder
                .seq(seq -> broadcastLowShares(seq, authenticatedElements))
                // The result of the low-share broadcast is not needed by later steps.
                .seq((seq, ignored) -> computePValues(seq, authenticatedElements, r))
                .seq((seq, broadcastPjs) ->
                        computeZValues(seq, authenticatedElements, keyShare, y, r, broadcastPjs))
                .seq((seq, zValues) -> {
                    PlainT sum = UInt.sum(serializer.deserializeList(zValues));
                    if (!sum.isZero()) {
                        throw new MaliciousException("Mac check failed");
                    }
                    openedValueStore.clear();
                    return null;
                });
    }

    /**
     * Broadcasts the least significant part of this party's share of every opened value, when
     * there are more than two parties. Otherwise it returns an already-available empty result.
     */
    private DRes<List<byte[]>> broadcastLowShares(ProtocolBuilderNumeric builder,
            List<Spdz2kSInt<PlainT>> authenticatedElements) {
        if (noOfParties <= 2) {
            return () -> null;
        }
        List<byte[]> lowShares = authenticatedElements.stream()
                .map(element -> element.getShare().getLeastSignificant().toByteArray())
                .collect(Collectors.toList());
        return new BroadcastComputation<ProtocolBuilderNumeric>(lowShares)
                .buildComputation(builder);
    }

    /**
     * Returns one term of this party's p-value: the high-part difference of {@code share}
     * multiplied by the low part (as a high value) of {@code coefficient}.
     */
    private HighT computePj(PlainT share, PlainT coefficient) {
        return computeDifference(share).multiply(coefficient.getLeastSignificantAsHigh());
    }

    /**
     * Computes this party's p-value as low(r) + sum_i computePj(share_i, coefficient_i) and
     * broadcasts it.
     *
     * <p>The accumulator starts from r's low part instead of the first term. Addition in the ring
     * is commutative, so the value is unchanged. This form also does not read element 0.</p>
     */
    private DRes<List<byte[]>> computePValues(ProtocolBuilderNumeric builder,
            List<Spdz2kSInt<PlainT>> authenticatedElements, Spdz2kSInt<PlainT> r) {
        HighT pj = r.getShare().getLeastSignificantAsHigh();
        for (int i = 0; i < authenticatedElements.size(); i++) {
            pj = pj.add(computePj(authenticatedElements.get(i).getShare(),
                    randomCoefficients.get(i)));
        }
        return new BroadcastComputation<ProtocolBuilderNumeric>(pj.toByteArray())
                .buildComputation(builder);
    }

    /**
     * Computes this party's z-value and runs it through {@link Spdz2kCommitmentComputation}.
     *
     * <p>With {@code p} = the sum of the broadcast p-values (each deserialized and reduced to its
     * low part as a high value), the z-value is
     * {@code keyShare*y - <macShares, coefficients> - shiftLowIntoHigh(p*keyShare)
     * + shiftLowIntoHigh(macShare(r))}.</p>
     *
     * @param keyShare this party's share of the secret-shared key
     * @param y inner product of the opened values and the random coefficients
     * @param r random element share whose MAC share masks the result
     * @param broadcastPjs serialized p-values returned by the p-value broadcast
     */
    private DRes<List<byte[]>> computeZValues(ProtocolBuilderNumeric builder,
            List<Spdz2kSInt<PlainT>> authenticatedElements, PlainT keyShare, PlainT y,
            Spdz2kSInt<PlainT> r, List<byte[]> broadcastPjs) {
        List<HighT> pjs = serializer.deserializeList(broadcastPjs).stream()
                .map(value -> value.getLeastSignificantAsHigh())
                .collect(Collectors.toList());
        PlainT p = converter.createFromHigh(UInt.sum(pjs));
        List<PlainT> macShares = authenticatedElements.stream()
                .map(Spdz2kSInt::getMacShare)
                .collect(Collectors.toList());
        PlainT macCombination = UInt.innerProduct(macShares, randomCoefficients);
        PlainT zj = keyShare.multiply(y)
                .subtract(macCombination)
                .subtract(p.multiply(keyShare).shiftLowIntoHigh())
                .add(r.getMacShare().shiftLowIntoHigh());
        return new Spdz2kCommitmentComputation(commitmentSerializer, serializer.serialize(zj),
                noOfParties, localDrbg).buildComputation(builder);
    }

    /**
     * Samples {@code count} coefficients. Each one is created from
     * {@code factory.getHighBitLength() / 8} bytes drawn from {@code drbg}.
     */
    private List<PlainT> sampleCoefficients(Drbg drbg, CompUIntFactory<PlainT> factory,
            int count) {
        List<PlainT> coefficients = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            byte[] bytes = new byte[factory.getHighBitLength() / 8];
            drbg.nextBytes(bytes);
            coefficients.add(factory.createFromBytes(bytes));
        }
        return coefficients;
    }

    /**
     * Returns the most significant part of (share's low part lifted to a full value) - share.
     */
    private HighT computeDifference(PlainT share) {
        return converter.createFromLow(share.getLeastSignificant())
                .subtract(share)
                .getMostSignificant();
    }
}
