package dk.alexandra.fresco.suite.spdz;

import dk.alexandra.fresco.framework.DRes;
import dk.alexandra.fresco.framework.TestThreadRunner;
import dk.alexandra.fresco.framework.builder.numeric.DefaultPreprocessedValues;
import dk.alexandra.fresco.framework.builder.numeric.ProtocolBuilderNumeric;
import dk.alexandra.fresco.framework.configuration.NetworkConfiguration;
import dk.alexandra.fresco.framework.configuration.NetworkTestUtils;
import dk.alexandra.fresco.framework.network.AsyncNetwork;
import dk.alexandra.fresco.framework.network.CloseableNetwork;
import dk.alexandra.fresco.framework.network.Network;
import dk.alexandra.fresco.framework.sce.SecureComputationEngineImpl;
import dk.alexandra.fresco.framework.sce.evaluator.BatchedProtocolEvaluator;
import dk.alexandra.fresco.framework.sce.evaluator.BatchedStrategy;
import dk.alexandra.fresco.framework.sce.evaluator.EvaluationStrategy;
import dk.alexandra.fresco.framework.sce.resources.storage.FilebasedStreamedStorageImpl;
import dk.alexandra.fresco.framework.sce.resources.storage.InMemoryStorage;
import dk.alexandra.fresco.framework.util.AesCtrDrbg;
import dk.alexandra.fresco.framework.util.Drbg;
import dk.alexandra.fresco.framework.util.ModulusFinder;
import dk.alexandra.fresco.framework.util.OpenedValueStoreImpl;
import dk.alexandra.fresco.framework.util.PaddingAesCtrDrbg;
import dk.alexandra.fresco.framework.value.SInt;
import dk.alexandra.fresco.lib.field.integer.BasicNumericContext;
import dk.alexandra.fresco.lib.real.RealNumericContext;
import dk.alexandra.fresco.logging.BatchEvaluationLoggingDecorator;
import dk.alexandra.fresco.logging.DefaultPerformancePrinter;
import dk.alexandra.fresco.logging.EvaluatorLoggingDecorator;
import dk.alexandra.fresco.logging.NetworkLoggingDecorator;
import dk.alexandra.fresco.logging.NumericSuiteLogging;
import dk.alexandra.fresco.logging.PerformanceLogger;
import dk.alexandra.fresco.logging.PerformanceLoggerCountingAggregate;
import dk.alexandra.fresco.suite.spdz.configuration.PreprocessingStrategy;
import dk.alexandra.fresco.suite.spdz.datatypes.SpdzSInt;
import dk.alexandra.fresco.suite.spdz.storage.SpdzDummyDataSupplier;
import dk.alexandra.fresco.suite.spdz.storage.SpdzMascotDataSupplier;
import dk.alexandra.fresco.suite.spdz.storage.SpdzStorageDataSupplier;
import dk.alexandra.fresco.tools.mascot.field.FieldElement;
import dk.alexandra.fresco.tools.ot.base.DummyOt;
import dk.alexandra.fresco.tools.ot.otextension.RotList;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:dk/alexandra/fresco/suite/spdz/AbstractSpdzTest.class */
/**
 * Base class for tests that run a computation with the SPDZ protocol suite.
 *
 * <p>For every party it builds a protocol suite, a batched evaluator, a resource pool and a
 * network, hands the resulting per-party configurations to {@link TestThreadRunner#run}, prints
 * the collected performance logs and finally closes the helper networks it opened.
 */
public abstract class AbstractSpdzTest {
    private static final int DEFAULT_MOD_BIT_LENGTH = 512;
    private static final int DEFAULT_MAX_BIT_LENGTH = 150;
    private static final int DEFAULT_FIXED_POINT_PRECISION = 16;
    private static final int PRG_SEED_LENGTH = 256;
    /** Base value from which the per-party port numbers are derived. */
    private static final int BASE_PORT = 9000;

    /** Performance loggers keyed by party id; an entry is only added if none exists yet. */
    protected Map<Integer, PerformanceLogger> performanceLoggers = new HashMap<>();
    private int modBitLength = DEFAULT_MOD_BIT_LENGTH;
    private int maxBitLength = DEFAULT_MAX_BIT_LENGTH;
    private int fixedPointPrecision = DEFAULT_FIXED_POINT_PRECISION;

    /**
     * Runs the given test with one configuration per party.
     *
     * @param testThreadFactory factory for the test threads handed to {@link TestThreadRunner#run}
     * @param evaluationStrategy supplies the batch evaluation strategy for the evaluator
     * @param preprocessingStrategy selects which data supplier backs each resource pool
     * @param noOfParties number of parties; also determines the port numbers used
     * @param logPerformance if {@code true}, suite, strategy, evaluator and network are wrapped in
     *     logging decorators that are registered with the party's aggregate logger
     * @param modBitLength bit length used when searching for a modulus
     * @param maxBitLength maximum bit length passed to the protocol suite and numeric context
     * @param fixedPointPrecision precision passed to the real-number context
     */
    public void runTest(TestThreadRunner.TestThreadFactory<SpdzResourcePool, ProtocolBuilderNumeric> testThreadFactory, EvaluationStrategy evaluationStrategy, PreprocessingStrategy preprocessingStrategy, int noOfParties, boolean logPerformance, int modBitLength, int maxBitLength, int fixedPointPrecision) {
        this.modBitLength = modBitLength;
        this.maxBitLength = maxBitLength;
        this.fixedPointPrecision = fixedPointPrecision;

        List<Integer> ports = new ArrayList<>(noOfParties);
        for (int party = 1; party <= noOfParties; party++) {
            ports.add(BASE_PORT + (party * (noOfParties - 1)));
        }

        // Three independent network managers over the same ports: one for the main MASCOT
        // supplier, one for the seed OTs and one for the exponentiation pipe.
        NetManager supplierNetworks = new NetManager(ports);
        NetManager seedOtNetworks = new NetManager(ports);
        NetManager pipeNetworks = new NetManager(ports);
        try {
            Map<Integer, NetworkConfiguration> netConf = NetworkTestUtils.getNetworkConfigurations(noOfParties, ports);
            Map<Integer, TestThreadRunner.TestThreadConfiguration<SpdzResourcePool, ProtocolBuilderNumeric>> conf = new HashMap<>();
            for (int playerId : netConf.keySet()) {
                PerformanceLoggerCountingAggregate aggregate = new PerformanceLoggerCountingAggregate();

                // NOTE(review): the PerformanceLogger declared type on the three locals below is
                // carried over unchanged from the decompiled form of this class and looks like a
                // decompilation artifact - confirm the intended types against the original sources.
                PerformanceLogger suite = createProtocolSuite(maxBitLength);
                if (logPerformance) {
                    suite = new NumericSuiteLogging(suite);
                    aggregate.add(suite);
                }
                PerformanceLogger strategy = evaluationStrategy.getStrategy();
                if (logPerformance) {
                    strategy = new BatchEvaluationLoggingDecorator(strategy);
                    aggregate.add(strategy);
                }
                PerformanceLogger evaluator = new BatchedProtocolEvaluator(strategy, suite);
                if (logPerformance) {
                    evaluator = new EvaluatorLoggingDecorator(evaluator);
                    aggregate.add(evaluator);
                }

                conf.put(playerId, new TestThreadRunner.TestThreadConfiguration(
                        new SecureComputationEngineImpl(suite, evaluator),
                        () -> createResourcePool(playerId, noOfParties, preprocessingStrategy, seedOtNetworks, supplierNetworks, pipeNetworks),
                        () -> {
                            AsyncNetwork network = new AsyncNetwork(netConf.get(playerId));
                            if (!logPerformance) {
                                return network;
                            }
                            NetworkLoggingDecorator loggingNetwork = new NetworkLoggingDecorator(network);
                            aggregate.add(loggingNetwork);
                            return loggingNetwork;
                        }));
                this.performanceLoggers.putIfAbsent(playerId, aggregate);
            }

            TestThreadRunner.run(testThreadFactory, conf);

            DefaultPerformancePrinter printer = new DefaultPerformancePrinter();
            for (PerformanceLogger logger : this.performanceLoggers.values()) {
                printer.printPerformanceLog(logger);
            }
        } finally {
            // Close every manager, including the seed-OT one, even when the run fails.
            supplierNetworks.close();
            seedOtNetworks.close();
            pipeNetworks.close();
        }
    }

    /**
     * Creates the protocol suite used by the evaluators; subclasses may override.
     *
     * @param i maximum bit length handed to the suite
     * @return a new {@link SpdzProtocolSuite}
     */
    protected SpdzProtocolSuite createProtocolSuite(int i) {
        return new SpdzProtocolSuite(i);
    }

    /**
     * Runs the test with the default modulus bit length, maximum bit length and fixed-point
     * precision.
     *
     * @param i number of parties
     * @param z whether to enable performance logging
     */
    public void runTest(TestThreadRunner.TestThreadFactory<SpdzResourcePool, ProtocolBuilderNumeric> testThreadFactory, EvaluationStrategy evaluationStrategy, PreprocessingStrategy preprocessingStrategy, int i, boolean z) {
        runTest(testThreadFactory, evaluationStrategy, preprocessingStrategy, i, z, DEFAULT_MOD_BIT_LENGTH, DEFAULT_MAX_BIT_LENGTH, DEFAULT_FIXED_POINT_PRECISION);
    }

    /**
     * Runs the test with default numeric parameters and performance logging disabled.
     *
     * @param i number of parties
     */
    public void runTest(TestThreadRunner.TestThreadFactory<SpdzResourcePool, ProtocolBuilderNumeric> testThreadFactory, EvaluationStrategy evaluationStrategy, PreprocessingStrategy preprocessingStrategy, int i) {
        runTest(testThreadFactory, evaluationStrategy, preprocessingStrategy, i, false, DEFAULT_MOD_BIT_LENGTH, DEFAULT_MAX_BIT_LENGTH, DEFAULT_FIXED_POINT_PRECISION);
    }

    /**
     * Runs the test with explicit numeric parameters and performance logging disabled.
     *
     * @param i number of parties
     * @param i2 bit length used when searching for a modulus
     * @param i3 maximum bit length
     * @param i4 fixed-point precision
     */
    public void runTest(TestThreadRunner.TestThreadFactory<SpdzResourcePool, ProtocolBuilderNumeric> testThreadFactory, EvaluationStrategy evaluationStrategy, PreprocessingStrategy preprocessingStrategy, int i, int i2, int i3, int i4) {
        runTest(testThreadFactory, evaluationStrategy, preprocessingStrategy, i, false, i2, i3, i4);
    }

    /**
     * Builds and evaluates an exponentiation pipe of the requested length.
     *
     * @param i id of this party
     * @param i2 number of parties
     * @param i3 length passed to {@code getExponentiationPipe}
     * @param closeableNetwork network the evaluation runs over
     * @param spdzMascotDataSupplier supplier backing the temporary resource pool; also provides
     *     the modulus for the numeric context
     * @return the deferred pipe; it has been evaluated by the time this method returns
     */
    DRes<List<DRes<SInt>>> createPipe(int i, int i2, int i3, CloseableNetwork closeableNetwork, SpdzMascotDataSupplier spdzMascotDataSupplier) {
        BasicNumericContext numericContext = new BasicNumericContext(this.maxBitLength, spdzMascotDataSupplier.getModulus(), i, i2);
        ProtocolBuilderNumeric builder = new SpdzBuilder(numericContext, new RealNumericContext(this.fixedPointPrecision)).createSequential();
        SpdzResourcePoolImpl resourcePool = new SpdzResourcePoolImpl(i, i2, new OpenedValueStoreImpl(), spdzMascotDataSupplier, new AesCtrDrbg(new byte[32]));
        DRes<List<DRes<SInt>>> pipe = new DefaultPreprocessedValues(builder).getExponentiationPipe(i3);
        evaluate(builder, resourcePool, closeableNetwork);
        return pipe;
    }

    /**
     * Creates a deterministic DRBG whose seed bytes come from {@code new Random(myId)}.
     *
     * @param myId value used to seed the {@link Random} that fills the seed
     * @param seedBitLength seed length in bits; the seed array has {@code seedBitLength / 8} bytes
     */
    private Drbg getDrbg(int myId, int seedBitLength) {
        byte[] seed = new byte[seedBitLength / 8];
        new Random(myId).nextBytes(seed);
        return new PaddingAesCtrDrbg(seed);
    }

    /**
     * Runs a dummy-OT based seed exchange with every other party.
     *
     * <p>The party with the lower id sends first and then receives; the other party does the
     * reverse, so the two sides of each pair are ordered consistently.
     *
     * @return one {@link RotList} per other party, keyed by that party's id
     */
    private Map<Integer, RotList> getSeedOts(int myId, List<Integer> partyIds, int seedBitLength, Drbg drbg, Network network) {
        Map<Integer, RotList> seedOts = new HashMap<>();
        for (int otherId : partyIds) {
            if (otherId == myId) {
                continue;
            }
            DummyOt ot = new DummyOt(otherId, network);
            RotList rotList = new RotList(drbg, seedBitLength);
            if (myId < otherId) {
                rotList.send(ot);
                rotList.receive(ot);
            } else {
                rotList.receive(ot);
                rotList.send(ot);
            }
            seedOts.put(otherId, rotList);
        }
        return seedOts;
    }

    /**
     * Creates the resource pool for one party, backed by the data supplier selected through
     * {@code preprocessingStrategy}: a dummy supplier for {@code DUMMY}, a MASCOT supplier for
     * {@code MASCOT}, and a storage-based supplier otherwise.
     */
    private SpdzResourcePool createResourcePool(int myId, int noOfParties, PreprocessingStrategy preprocessingStrategy, NetManager seedOtNetworks, NetManager supplierNetworks, NetManager pipeNetworks) {
        if (preprocessingStrategy == PreprocessingStrategy.DUMMY) {
            SpdzDummyDataSupplier dummySupplier = new SpdzDummyDataSupplier(myId, noOfParties, ModulusFinder.findSuitableModulus(this.modBitLength));
            return new SpdzResourcePoolImpl(myId, noOfParties, new OpenedValueStoreImpl(), dummySupplier, new AesCtrDrbg(new byte[32]));
        }
        if (preprocessingStrategy == PreprocessingStrategy.MASCOT) {
            SpdzMascotDataSupplier mascotSupplier = createMascotSupplier(myId, noOfParties, seedOtNetworks, supplierNetworks, pipeNetworks);
            return new SpdzResourcePoolImpl(myId, noOfParties, new OpenedValueStoreImpl(), mascotSupplier, new AesCtrDrbg(new byte[32]));
        }
        SpdzStorageDataSupplier storageSupplier = new SpdzStorageDataSupplier(new FilebasedStreamedStorageImpl(new InMemoryStorage()), "spdz/SPDZ_1_" + myId + "_0_", noOfParties);
        return new SpdzResourcePoolImpl(myId, noOfParties, new OpenedValueStoreImpl(), storageSupplier, new AesCtrDrbg(new byte[32]));
    }

    /**
     * Builds the MASCOT data supplier for one party.
     *
     * <p>The seed OTs are exchanged eagerly over a network from {@code seedOtNetworks}. The main
     * supplier obtains its network from {@code supplierNetworks} through the supplied lambda.
     * The exponentiation-pipe function lazily opens its own network from {@code pipeNetworks}
     * and its own inner supplier the first time it is applied.
     */
    private SpdzMascotDataSupplier createMascotSupplier(final int myId, final int noOfParties, NetManager seedOtNetworks, NetManager supplierNetworks, final NetManager pipeNetworks) {
        List<Integer> partyIds = IntStream.range(1, noOfParties + 1).boxed().collect(Collectors.toList());
        final Drbg drbg = getDrbg(myId, PRG_SEED_LENGTH);
        final BigInteger modulus = ModulusFinder.findSuitableModulus(this.modBitLength);
        final Map<Integer, RotList> seedOts = getSeedOts(myId, partyIds, PRG_SEED_LENGTH, drbg, seedOtNetworks.createExtraNetwork(myId));
        final FieldElement ssk = SpdzMascotDataSupplier.createRandomSsk(modulus, PRG_SEED_LENGTH);

        Function<Integer, SpdzSInt[]> pipeFunction = new Function<Integer, SpdzSInt[]>() {
            private SpdzMascotDataSupplier tripleSupplier;
            private CloseableNetwork pipeNetwork;

            @Override
            public SpdzSInt[] apply(Integer pipeLength) {
                if (pipeNetwork == null) {
                    // First use: open the dedicated network and the inner supplier that feeds
                    // the pipe computation. The inner supplier gets no pipe function itself.
                    pipeNetwork = pipeNetworks.createExtraNetwork(myId);
                    tripleSupplier = SpdzMascotDataSupplier.createSimpleSupplier(myId, noOfParties, () -> pipeNetwork, AbstractSpdzTest.this.modBitLength, modulus, (Function<Integer, SpdzSInt[]>) null, seedOts, drbg, ssk);
                }
                DRes<List<DRes<SInt>>> pipe = createPipe(myId, noOfParties, pipeLength.intValue(), pipeNetwork, tripleSupplier);
                return computeSInts(pipe);
            }
        };

        return SpdzMascotDataSupplier.createSimpleSupplier(myId, noOfParties, () -> supplierNetworks.createExtraNetwork(myId), this.modBitLength, modulus, pipeFunction, seedOts, drbg, ssk);
    }

    /**
     * Resolves an evaluated pipe into an array of {@link SpdzSInt} values, preserving order.
     *
     * @param dRes deferred list of deferred values; every element must resolve to a
     *     {@link SpdzSInt}
     */
    public SpdzSInt[] computeSInts(DRes<List<DRes<SInt>>> dRes) {
        List<DRes<SInt>> elements = dRes.out();
        SpdzSInt[] result = new SpdzSInt[elements.size()];
        int index = 0;
        for (DRes<SInt> element : elements) {
            result[index++] = (SpdzSInt) element.out();
        }
        return result;
    }

    /**
     * Evaluates the protocol built by {@code protocolBuilderNumeric} with a fresh batched
     * evaluator that uses {@link BatchedStrategy} and a suite from {@link #createProtocolSuite}.
     */
    private void evaluate(ProtocolBuilderNumeric protocolBuilderNumeric, SpdzResourcePool spdzResourcePool, Network network) {
        BatchedProtocolEvaluator evaluator = new BatchedProtocolEvaluator(new BatchedStrategy(), createProtocolSuite(this.maxBitLength));
        evaluator.eval(protocolBuilderNumeric.build(), spdzResourcePool, network);
    }
}
