package dk.alexandra.fresco.lib.statistics;

import dk.alexandra.fresco.framework.DRes;
import dk.alexandra.fresco.framework.TestThreadRunner;
import dk.alexandra.fresco.framework.builder.numeric.Numeric;
import dk.alexandra.fresco.framework.builder.numeric.ProtocolBuilderNumeric;
import dk.alexandra.fresco.framework.builder.numeric.field.FieldDefinition;
import dk.alexandra.fresco.framework.sce.resources.ResourcePool;
import dk.alexandra.fresco.framework.value.SInt;
import dk.alexandra.fresco.lib.lp.LPSolver;
import dk.alexandra.fresco.lib.statistics.DeaSolver;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.hamcrest.core.IsNull;
import org.junit.Assert;

/**
 * Test factories for {@link DeaSolver}: DEA instances (fixed or randomly generated) are solved
 * under MPC and the opened results are compared with {@link PlaintextDEASolver}.
 */
public class DeaSolverTests {

    /**
     * DEA test on randomly generated data.
     *
     * <p>Basis and target matrices are filled with independent random non-negative integers of at
     * most {@link #BIT_LENGTH} bits.
     */
    public static class RandomDataDeaTest<ResourcePoolT extends ResourcePool> extends TestDeaSolver<ResourcePoolT> {
        /** Bit length of each random entry; entries are uniform in {@code [0, 2^9 - 1]}. */
        private static final int BIT_LENGTH = 9;

        /**
         * Creates a random test with a fixed seed (2) so runs are reproducible, 50 LP iterations and
         * no expected break.
         *
         * @param inputVariables number of input columns
         * @param outputVariables number of output columns
         * @param basisRows number of rows in the basis (data set)
         * @param targetRows number of target rows to analyse
         * @param analysisType the DEA analysis type
         */
        public RandomDataDeaTest(int inputVariables, int outputVariables, int basisRows, int targetRows,
                DeaSolver.AnalysisType analysisType) {
            this(inputVariables, outputVariables, basisRows, targetRows, analysisType, new Random(2L));
        }

        private RandomDataDeaTest(int inputVariables, int outputVariables, int basisRows, int targetRows,
                DeaSolver.AnalysisType analysisType, Random random) {
            this(inputVariables, outputVariables, basisRows, targetRows, analysisType, random, 50, false);
        }

        /**
         * Creates a random test.
         *
         * @param inputVariables number of input columns
         * @param outputVariables number of output columns
         * @param basisRows number of rows in the basis (data set)
         * @param targetRows number of target rows to analyse
         * @param analysisType the DEA analysis type
         * @param random source of the random entries
         * @param maxNoOfIterations maximum number of LP iterations
         * @param willBreak whether the solver is expected to break (results with a {@code null} optimum)
         */
        public RandomDataDeaTest(int inputVariables, int outputVariables, int basisRows, int targetRows,
                DeaSolver.AnalysisType analysisType, Random random, int maxNoOfIterations, boolean willBreak) {
            // Arguments are evaluated left to right, so the random draws happen in this fixed order:
            // basis inputs, basis outputs, target inputs, target outputs.
            super(randomMatrix(basisRows, inputVariables, random),
                    randomMatrix(basisRows, outputVariables, random),
                    randomMatrix(targetRows, inputVariables, random),
                    randomMatrix(targetRows, outputVariables, random),
                    analysisType, maxNoOfIterations, willBreak);
        }

        /**
         * Builds a {@code rows x columns} matrix of random {@link #BIT_LENGTH}-bit values, row by row.
         */
        private static List<List<BigInteger>> randomMatrix(int rows, int columns, Random random) {
            List<List<BigInteger>> matrix = new ArrayList<>(Math.max(rows, 0));
            for (int r = 0; r < rows; r++) {
                List<BigInteger> row = new ArrayList<>(Math.max(columns, 0));
                for (int c = 0; c < columns; c++) {
                    row.add(new BigInteger(BIT_LENGTH, random));
                }
                matrix.add(row);
            }
            return matrix;
        }
    }

    /* loaded from: input_file:dk/alexandra/fresco/lib/statistics/DeaSolverTests$TestDeaFixed1.class */
    public static class TestDeaFixed1<ResourcePoolT extends ResourcePool> extends TestDeaSolver<ResourcePoolT> {
        private static List<List<BigInteger>> inputs;
        private static List<List<BigInteger>> outputs;

        public TestDeaFixed1(DeaSolver.AnalysisType analysisType) {
            super(inputs, outputs, inputs, outputs, analysisType, 50, false);
        }

        /* JADX WARN: Type inference failed for: r0v1, types: [int[], int[][]] */
        static {
            ?? r0 = {new int[]{29, 13451, 14409, 16477}, new int[]{2, 581, 531, 1037}, new int[]{26, 13352, 1753, 13528}, new int[]{15, 4828, 949, 5126}, new int[]{20, 6930, 6376, 9680}};
            inputs = DeaSolverTests.buildInputs(r0);
            outputs = DeaSolverTests.buildOutputs(r0);
        }
    }

    /* loaded from: input_file:dk/alexandra/fresco/lib/statistics/DeaSolverTests$TestDeaFixed2.class */
    public static class TestDeaFixed2<ResourcePoolT extends ResourcePool> extends TestDeaSolver<ResourcePoolT> {
        private static List<List<BigInteger>> inputs;
        private static List<List<BigInteger>> outputs;

        public TestDeaFixed2(DeaSolver.AnalysisType analysisType) {
            super(inputs, outputs, inputs, outputs, analysisType, 50, false);
        }

        /* JADX WARN: Type inference failed for: r0v1, types: [int[], int[][]] */
        static {
            ?? r0 = {new int[]{10, 20, 30, 1000}, new int[]{5, 10, 15, 1000}, new int[]{200, 300, 400, 100}};
            inputs = DeaSolverTests.buildInputs(r0);
            outputs = DeaSolverTests.buildOutputs(r0);
        }
    }

    /* loaded from: input_file:dk/alexandra/fresco/lib/statistics/DeaSolverTests$TestDeaSolver.class */
    public static class TestDeaSolver<ResourcePoolT extends ResourcePool> extends TestThreadRunner.TestThreadFactory<ResourcePoolT, ProtocolBuilderNumeric> {
        private final List<List<BigInteger>> rawTargetOutputs;
        private final List<List<BigInteger>> rawTargetInputs;
        private final List<List<BigInteger>> rawBasisOutputs;
        private final List<List<BigInteger>> rawBasisInputs;
        private final DeaSolver.AnalysisType type;
        private final int maxNoOfIterations;
        private boolean willBreak;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:dk/alexandra/fresco/lib/statistics/DeaSolverTests$TestDeaSolver$OpenDeaResult.class */
        public static class OpenDeaResult {
            private final BigInteger optimal;
            private final BigInteger numerator;
            private final BigInteger denominator;
            private final List<BigInteger> peers;
            private final List<BigInteger> peerValues;

            private OpenDeaResult(BigInteger bigInteger, BigInteger bigInteger2, BigInteger bigInteger3, List<BigInteger> list, List<BigInteger> list2) {
                this.optimal = bigInteger;
                this.numerator = bigInteger2;
                this.denominator = bigInteger3;
                this.peers = list;
                this.peerValues = list2;
            }

            static DRes<OpenDeaResult> createOpenDeaResult(Numeric numeric, DeaSolver.DeaResult deaResult, FieldDefinition fieldDefinition) {
                DRes open = numeric.open(deaResult.optimal);
                DRes open2 = numeric.open(deaResult.numerator);
                DRes open3 = numeric.open(deaResult.denominator);
                Stream stream = deaResult.peers.stream();
                numeric.getClass();
                List list = (List) stream.map((v1) -> {
                    return r1.open(v1);
                }).collect(Collectors.toList());
                Stream stream2 = deaResult.peerValues.stream();
                numeric.getClass();
                List list2 = (List) stream2.map((v1) -> {
                    return r1.open(v1);
                }).collect(Collectors.toList());
                return () -> {
                    BigInteger convertToSigned = fieldDefinition.convertToSigned((BigInteger) open.out());
                    BigInteger convertToSigned2 = fieldDefinition.convertToSigned((BigInteger) open2.out());
                    BigInteger convertToSigned3 = fieldDefinition.convertToSigned((BigInteger) open3.out());
                    Stream map = list.stream().map((v0) -> {
                        return v0.out();
                    });
                    fieldDefinition.getClass();
                    List list3 = (List) map.map(fieldDefinition::convertToSigned).collect(Collectors.toList());
                    Stream map2 = list2.stream().map((v0) -> {
                        return v0.out();
                    });
                    fieldDefinition.getClass();
                    return new OpenDeaResult(convertToSigned, convertToSigned2, convertToSigned3, list3, (List) map2.map(fieldDefinition::convertToSigned).collect(Collectors.toList()));
                };
            }
        }

        public TestDeaSolver(List<List<BigInteger>> list, List<List<BigInteger>> list2, List<List<BigInteger>> list3, List<List<BigInteger>> list4, DeaSolver.AnalysisType analysisType, int i, boolean z) {
            this.rawTargetOutputs = list4;
            this.rawTargetInputs = list3;
            this.rawBasisOutputs = list2;
            this.rawBasisInputs = list;
            this.type = analysisType;
            this.maxNoOfIterations = i;
            this.willBreak = z;
        }

        public TestThreadRunner.TestThread<ResourcePoolT, ProtocolBuilderNumeric> next() {
            return (TestThreadRunner.TestThread<ResourcePoolT, ProtocolBuilderNumeric>) new TestThreadRunner.TestThread<ResourcePoolT, ProtocolBuilderNumeric>() { // from class: dk.alexandra.fresco.lib.statistics.DeaSolverTests.TestDeaSolver.1
                private BigInteger modulus;

                public void test() {
                    List list = (List) runApplication((DeaSolver) runApplication(protocolBuilderNumeric -> {
                        this.modulus = protocolBuilderNumeric.getBasicNumericContext().getModulus();
                        Numeric numeric = protocolBuilderNumeric.numeric();
                        List knownMatrix = DeaSolverTests.knownMatrix(numeric, TestDeaSolver.this.rawTargetOutputs);
                        List knownMatrix2 = DeaSolverTests.knownMatrix(numeric, TestDeaSolver.this.rawTargetInputs);
                        List knownMatrix3 = DeaSolverTests.knownMatrix(numeric, TestDeaSolver.this.rawBasisOutputs);
                        List knownMatrix4 = DeaSolverTests.knownMatrix(numeric, TestDeaSolver.this.rawBasisInputs);
                        return () -> {
                            return new DeaSolver(LPSolver.PivotRule.DANZIG, TestDeaSolver.this.type, knownMatrix2, knownMatrix, knownMatrix4, knownMatrix3, TestDeaSolver.this.maxNoOfIterations);
                        };
                    }));
                    if (TestDeaSolver.this.willBreak) {
                        list.forEach(deaResult -> {
                            Assert.assertThat(deaResult.optimal, IsNull.nullValue());
                        });
                        return;
                    }
                    List list2 = (List) runApplication(protocolBuilderNumeric2 -> {
                        Numeric numeric = protocolBuilderNumeric2.numeric();
                        ArrayList arrayList = new ArrayList();
                        Iterator it = list.iterator();
                        while (it.hasNext()) {
                            arrayList.add(OpenDeaResult.createOpenDeaResult(numeric, (DeaSolver.DeaResult) it.next(), protocolBuilderNumeric2.getBasicNumericContext().getFieldDefinition()));
                        }
                        return () -> {
                            return (List) arrayList.stream().map((v0) -> {
                                return v0.out();
                            }).collect(Collectors.toList());
                        };
                    });
                    PlaintextDEASolver plaintextDEASolver = new PlaintextDEASolver();
                    plaintextDEASolver.addBasis(TestDeaSolver.this.asArray(TestDeaSolver.this.rawBasisInputs), TestDeaSolver.this.asArray(TestDeaSolver.this.rawBasisOutputs));
                    double[] solve = plaintextDEASolver.solve(TestDeaSolver.this.asArray(TestDeaSolver.this.rawTargetInputs), TestDeaSolver.this.asArray(TestDeaSolver.this.rawTargetOutputs), TestDeaSolver.this.type);
                    int size = TestDeaSolver.this.type == DeaSolver.AnalysisType.INPUT_EFFICIENCY ? TestDeaSolver.this.rawBasisInputs.size() : TestDeaSolver.this.rawBasisInputs.size() + 1;
                    for (int i = 0; i < TestDeaSolver.this.rawTargetInputs.size(); i++) {
                        OpenDeaResult openDeaResult = (OpenDeaResult) list2.get(i);
                        Assert.assertEquals(solve[i], DeaSolverTests.postProcess(openDeaResult.optimal, TestDeaSolver.this.type, this.modulus), 1.0E-7d);
                        double doubleValue = openDeaResult.numerator.doubleValue() / openDeaResult.denominator.doubleValue();
                        if (TestDeaSolver.this.type == DeaSolver.AnalysisType.INPUT_EFFICIENCY) {
                            doubleValue *= -1.0d;
                        }
                        Assert.assertEquals(solve[i], doubleValue, 1.0E-7d);
                        List list3 = openDeaResult.peers;
                        List list4 = openDeaResult.peerValues;
                        double d = 0.0d;
                        Iterator it = list4.iterator();
                        while (it.hasNext()) {
                            d += DeaSolverTests.postProcess((BigInteger) it.next(), DeaSolver.AnalysisType.OUTPUT_EFFICIENCY, this.modulus);
                        }
                        Assert.assertEquals("Peer values summed to " + d + " instead of 1", 1.0d, d, 1.0E-6d);
                        Iterator it2 = list3.iterator();
                        while (it2.hasNext()) {
                            int intValue = ((BigInteger) it2.next()).intValue();
                            Assert.assertTrue("Peer index" + intValue + ", was larger than " + (size - 1), intValue < size);
                        }
                        for (int i2 = 0; i2 < ((List) TestDeaSolver.this.rawTargetInputs.get(i)).size(); i2++) {
                            List constraintRow = TestDeaSolver.this.getConstraintRow(i2, TestDeaSolver.this.rawBasisInputs);
                            double d2 = 0.0d;
                            for (int i3 = 0; i3 < constraintRow.size(); i3++) {
                                int indexOf = list3.indexOf(BigInteger.valueOf(i3));
                                if (indexOf > -1) {
                                    d2 += DeaSolverTests.postProcess((BigInteger) list4.get(indexOf), DeaSolver.AnalysisType.OUTPUT_EFFICIENCY, this.modulus) * ((BigInteger) constraintRow.get(i3)).doubleValue();
                                }
                            }
                            Assert.assertTrue(d2 - (TestDeaSolver.this.type == DeaSolver.AnalysisType.INPUT_EFFICIENCY ? ((BigInteger) ((List) TestDeaSolver.this.rawTargetInputs.get(i)).get(i2)).doubleValue() * solve[i] : ((BigInteger) ((List) TestDeaSolver.this.rawTargetInputs.get(i)).get(i2)).doubleValue()) < 1.0E-6d);
                        }
                        for (int i4 = 0; i4 < ((List) TestDeaSolver.this.rawTargetOutputs.get(i)).size(); i4++) {
                            List constraintRow2 = TestDeaSolver.this.getConstraintRow(i4, TestDeaSolver.this.rawBasisOutputs);
                            if (TestDeaSolver.this.type == DeaSolver.AnalysisType.OUTPUT_EFFICIENCY) {
                                constraintRow2.add(((List) TestDeaSolver.this.rawTargetOutputs.get(i)).get(i4));
                            }
                            double d3 = 0.0d;
                            for (int i5 = 0; i5 < constraintRow2.size(); i5++) {
                                int indexOf2 = list3.indexOf(BigInteger.valueOf(i5));
                                if (indexOf2 > -1) {
                                    d3 += DeaSolverTests.postProcess((BigInteger) list4.get(indexOf2), DeaSolver.AnalysisType.OUTPUT_EFFICIENCY, this.modulus) * ((BigInteger) constraintRow2.get(i5)).doubleValue();
                                }
                            }
                            Assert.assertTrue((TestDeaSolver.this.type == DeaSolver.AnalysisType.INPUT_EFFICIENCY ? ((BigInteger) ((List) TestDeaSolver.this.rawTargetOutputs.get(i)).get(i4)).doubleValue() : ((BigInteger) ((List) TestDeaSolver.this.rawTargetOutputs.get(i)).get(i4)).doubleValue() * solve[i]) - d3 < 1.0E-6d);
                        }
                    }
                }
            };
        }

        /* JADX INFO: Access modifiers changed from: private */
        public List<BigInteger> getConstraintRow(int i, List<List<BigInteger>> list) {
            ArrayList arrayList = new ArrayList();
            Iterator<List<BigInteger>> it = list.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next().get(i));
            }
            return arrayList;
        }

        /* JADX INFO: Access modifiers changed from: private */
        public BigInteger[][] asArray(List<List<BigInteger>> list) {
            return (BigInteger[][]) list.stream().map(list2 -> {
                return (BigInteger[]) list2.toArray(new BigInteger[0]);
            }).toArray(i -> {
                return new BigInteger[i];
            });
        }
    }

    /**
     * Extracts the input part of a data table: every column of each row except the last one.
     *
     * <p>Each row's own length is used, so ragged tables are handled per row, as in
     * {@code buildOutputs}.
     *
     * @param iArr table rows; the last column of each row is not treated as an input
     * @return one mutable list of inputs per row, in row order (empty rows yield empty lists)
     */
    public static List<List<BigInteger>> buildInputs(int[][] iArr) {
        List<List<BigInteger>> inputs = new ArrayList<>(iArr.length);
        for (int[] row : iArr) {
            List<BigInteger> rowInputs = new ArrayList<>();
            for (int col = 0; col < row.length - 1; col++) {
                rowInputs.add(BigInteger.valueOf(row[col]));
            }
            inputs.add(rowInputs);
        }
        return inputs;
    }

    /**
     * Extracts the output part of a data table: the last column of each row.
     *
     * @param iArr table rows; each row must have at least one column
     * @return one mutable single-element list per row, in row order
     * @throws IllegalArgumentException if a row is empty
     */
    public static List<List<BigInteger>> buildOutputs(int[][] iArr) {
        List<List<BigInteger>> outputs = new ArrayList<>(iArr.length);
        for (int[] row : iArr) {
            if (row.length == 0) {
                throw new IllegalArgumentException("Cannot take an output value from an empty row");
            }
            List<BigInteger> rowOutputs = new ArrayList<>(1);
            rowOutputs.add(BigInteger.valueOf(row[row.length - 1]));
            outputs.add(rowOutputs);
        }
        return outputs;
    }

    /**
     * Loads a public matrix into the computation as known (non-secret) values.
     *
     * <p>Values are passed to {@link Numeric#known} row by row, left to right.
     *
     * @param numeric builder used to create the known values
     * @param list the matrix, as a list of rows
     * @return a matrix of the same shape holding the deferred known values
     */
    public static List<List<DRes<SInt>>> knownMatrix(Numeric numeric, List<List<BigInteger>> list) {
        List<List<DRes<SInt>>> matrix = new ArrayList<>(list.size());
        for (List<BigInteger> row : list) {
            List<DRes<SInt>> knownRow = new ArrayList<>(row.size());
            for (BigInteger value : row) {
                knownRow.add(numeric.known(value));
            }
            matrix.add(knownRow);
        }
        return matrix;
    }

    /**
     * Decodes an opened field element as a fraction {@code a / b} and returns it as a double.
     *
     * <p>The pair {@code (a, b)} comes from {@code gauss}, which returns a short vector of the
     * lattice spanned by {@code (modulus, 0)} and {@code (value mod modulus, 1)}. Every such vector
     * satisfies {@code a ≡ value * b (mod modulus)}.
     *
     * @param value the opened field element
     * @param type analysis type; {@code INPUT_EFFICIENCY} negates the result, any other type leaves it as is
     * @param modulus the field modulus
     * @return {@code a / b} as a double, negated for input efficiency
     */
    public static double postProcess(BigInteger value, DeaSolver.AnalysisType type, BigInteger modulus) {
        BigInteger[] fraction = gauss(value, modulus);
        BigInteger numerator = fraction[0];
        BigInteger denominator = fraction[1];
        double decoded = numerator.doubleValue() / denominator.doubleValue();
        return type == DeaSolver.AnalysisType.INPUT_EFFICIENCY ? decoded * -1.0d : decoded;
    }

    /**
     * Gauss (Lagrange) reduction of the 2-D lattice spanned by {@code (modulus, 0)} and
     * {@code (value mod modulus, 1)}.
     *
     * <p>At each step the current vector is subtracted from the previous one, using the
     * nearest-integer multiple of their inner-product ratio. The loop stops as soon as the newly
     * produced vector is no shorter than the one it was reduced against.
     *
     * @param value the field element to reduce
     * @param modulus the modulus; {@link BigInteger#mod} throws if it is not positive
     * @return a fresh array {@code {a, b}} holding the shortest vector found
     */
    private static BigInteger[] gauss(BigInteger value, BigInteger modulus) {
        BigInteger residue = value.mod(modulus);
        BigInteger two = BigInteger.valueOf(2L);
        BigInteger[] prev = {modulus, BigInteger.ZERO};
        BigInteger[] curr = {residue, BigInteger.ONE};
        for (;;) {
            BigInteger currNorm = innerproduct(curr, curr);
            BigInteger[] qr = innerproduct(prev, curr).divideAndRemainder(currNorm);
            BigInteger quotient = qr[0];
            BigInteger twiceRemainder = qr[1].multiply(two);
            // Turn the truncated quotient into the nearest integer; halves round away from zero.
            if (qr[1].signum() == -1) {
                if (currNorm.compareTo(twiceRemainder.negate()) <= 0) {
                    quotient = quotient.subtract(BigInteger.ONE);
                }
            } else if (currNorm.compareTo(twiceRemainder) <= 0) {
                quotient = quotient.add(BigInteger.ONE);
            }
            BigInteger[] next = {
                prev[0].subtract(curr[0].multiply(quotient)),
                prev[1].subtract(curr[1].multiply(quotient))
            };
            prev = curr;
            curr = next;
            // Stop once the new vector is not strictly shorter than the one it was reduced against.
            if (currNorm.compareTo(innerproduct(curr, curr)) <= 0) {
                return new BigInteger[] {prev[0], prev[1]};
            }
        }
    }

    /** Returns the inner product {@code left[0]*right[0] + left[1]*right[1]} of two 2-vectors. */
    private static BigInteger innerproduct(BigInteger[] left, BigInteger[] right) {
        BigInteger firstTerm = left[0].multiply(right[0]);
        BigInteger secondTerm = left[1].multiply(right[1]);
        return firstTerm.add(secondTerm);
    }
}
