package dk.alexandra.fresco.suite.spdz.storage;

import dk.alexandra.fresco.framework.util.TransposeUtils;
import dk.alexandra.fresco.suite.spdz.datatypes.SpdzInputMask;
import dk.alexandra.fresco.suite.spdz.datatypes.SpdzSInt;
import dk.alexandra.fresco.suite.spdz.datatypes.SpdzTriple;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:dk/alexandra/fresco/suite/spdz/storage/TestSpdzDummyDataSupplier.class */
/**
 * Tests for {@link SpdzDummyDataSupplier}.
 *
 * <p>Each test creates one supplier per party and draws the same kind of material from every
 * supplier. It then recombines the parties' shares by summing them with {@link SpdzSInt#add}.
 * Finally it checks two things:
 * <ul>
 *   <li>the recombined value carries a MAC equal to {@code value * macKey mod modulus}, where
 *       {@code macKey} is the sum of the parties' secret shared keys;</li>
 *   <li>the value satisfies the relation the material is meant to satisfy.</li>
 * </ul>
 * Every check is repeated for each modulus in {@link #moduli} and each party count in
 * {@link #PARTY_COUNTS}.
 */
public class TestSpdzDummyDataSupplier {

    /** Numbers of parties every test is run with. */
    private static final int[] PARTY_COUNTS = {2, 3, 5};

    /** Exp-pipe length used when a test does not ask for a specific one. */
    private static final int DEFAULT_EXP_PIPE_LENGTH = 200;

    /**
     * Smallest modulus under test. A random field element is zero with probability 1/251, so the
     * "random element is non-zero" check is skipped for it to keep the test from being flaky.
     */
    private static final BigInteger SMALL_MODULUS = new BigInteger("251");

    private final List<BigInteger> moduli = Arrays.asList(
            SMALL_MODULUS,
            new BigInteger("340282366920938463463374607431768211283"),
            new BigInteger("2582249878086908589655919172003011874329705792829223512830659356540647622016841194629645353280137831435903171972747493557"));

    /** Creates one supplier per party using the default exp-pipe length. */
    private List<SpdzDummyDataSupplier> setupSuppliers(int noOfParties, BigInteger modulus) {
        return setupSuppliers(noOfParties, modulus, DEFAULT_EXP_PIPE_LENGTH);
    }

    /**
     * Creates one supplier per party, with party ids 1..noOfParties.
     *
     * @param noOfParties number of parties
     * @param modulus the field modulus
     * @param expPipeLength exp-pipe length handed to every supplier
     * @return the suppliers, ordered by party id
     */
    private List<SpdzDummyDataSupplier> setupSuppliers(int noOfParties, BigInteger modulus,
            int expPipeLength) {
        List<SpdzDummyDataSupplier> suppliers = new ArrayList<>(noOfParties);
        Random random = new Random();
        for (int partyId = 1; partyId <= noOfParties; partyId++) {
            // Each party gets its own random secret shared key, reduced into [0, modulus).
            BigInteger keyShare = new BigInteger(modulus.bitLength(), random).mod(modulus);
            suppliers.add(new SpdzDummyDataSupplier(partyId, noOfParties, modulus, keyShare,
                    expPipeLength));
        }
        return suppliers;
    }

    /** Returns the global MAC key: the sum of all parties' key shares, reduced mod the modulus. */
    private BigInteger getMacKeyFromSuppliers(List<SpdzDummyDataSupplier> suppliers) {
        BigInteger macKey = BigInteger.ZERO;
        for (SpdzDummyDataSupplier supplier : suppliers) {
            macKey = macKey.add(supplier.getSecretSharedKey());
        }
        return macKey.mod(suppliers.get(0).getModulus());
    }

    private void testGetNextTriple(int noOfParties, BigInteger modulus) {
        List<SpdzDummyDataSupplier> suppliers = setupSuppliers(noOfParties, modulus);
        BigInteger macKey = getMacKeyFromSuppliers(suppliers);
        List<SpdzTriple> triples = new ArrayList<>(noOfParties);
        for (SpdzDummyDataSupplier supplier : suppliers) {
            triples.add(supplier.getNextTriple());
        }
        assertTripleValid(recombineTriples(triples), macKey, modulus);
    }

    private void testGetNextTriple(int noOfParties) {
        for (BigInteger modulus : moduli) {
            testGetNextTriple(noOfParties, modulus);
        }
    }

    /**
     * Checks input masks. Only the party whose id equals {@code towardsPlayerId} may see the real
     * value. The recombined mask must equal that real value and carry a valid MAC.
     */
    private void testGetNextInputMask(int noOfParties, int towardsPlayerId, BigInteger modulus) {
        List<SpdzDummyDataSupplier> suppliers = setupSuppliers(noOfParties, modulus);
        BigInteger macKey = getMacKeyFromSuppliers(suppliers);
        BigInteger realValue = null;
        List<SpdzSInt> maskShares = new ArrayList<>(noOfParties);
        for (int index = 0; index < noOfParties; index++) {
            int partyId = index + 1;
            SpdzInputMask inputMask = suppliers.get(index).getNextInputMask(towardsPlayerId);
            if (partyId == towardsPlayerId) {
                realValue = inputMask.getRealValue();
            } else {
                // Parties other than the input owner must not learn the masked value.
                Assert.assertNull(inputMask.getRealValue());
            }
            maskShares.add(inputMask.getMask());
        }
        SpdzSInt recombined = recombine(maskShares);
        assertMacCorrect(recombined, macKey, modulus);
        Assert.assertEquals(realValue, recombined.getShare());
    }

    private void testGetNextInputMask(int noOfParties, int towardsPlayerId) {
        for (BigInteger modulus : moduli) {
            testGetNextInputMask(noOfParties, towardsPlayerId, modulus);
        }
    }

    private void testGetNextBit(int noOfParties, BigInteger modulus) {
        List<SpdzDummyDataSupplier> suppliers = setupSuppliers(noOfParties, modulus);
        BigInteger macKey = getMacKeyFromSuppliers(suppliers);
        List<SpdzSInt> bitShares = new ArrayList<>(noOfParties);
        for (SpdzDummyDataSupplier supplier : suppliers) {
            bitShares.add(supplier.getNextBit());
        }
        SpdzSInt recombined = recombine(bitShares);
        assertMacCorrect(recombined, macKey, modulus);
        BigInteger value = recombined.getShare();
        Assert.assertTrue("Value not a bit " + value,
                value.equals(BigInteger.ZERO) || value.equals(BigInteger.ONE));
    }

    private void testGetNextBit(int noOfParties) {
        for (BigInteger modulus : moduli) {
            testGetNextBit(noOfParties, modulus);
        }
    }

    private void testGetNextRandomFieldElement(int noOfParties, BigInteger modulus) {
        List<SpdzDummyDataSupplier> suppliers = setupSuppliers(noOfParties, modulus);
        BigInteger macKey = getMacKeyFromSuppliers(suppliers);
        List<SpdzSInt> elementShares = new ArrayList<>(noOfParties);
        for (SpdzDummyDataSupplier supplier : suppliers) {
            elementShares.add(supplier.getNextRandomFieldElement());
        }
        SpdzSInt recombined = recombine(elementShares);
        assertMacCorrect(recombined, macKey, modulus);
        // Zero is a legitimate outcome with noticeable probability in a tiny field; only check
        // non-zero where a zero value would be astronomically unlikely.
        if (!modulus.equals(SMALL_MODULUS)) {
            Assert.assertFalse("Random value was 0 ",
                    recombined.getShare().equals(BigInteger.ZERO));
        }
    }

    private void testGetNextRandomFieldElement(int noOfParties) {
        for (BigInteger modulus : moduli) {
            testGetNextRandomFieldElement(noOfParties, modulus);
        }
    }

    /**
     * Checks exp pipes of the requested length. The suppliers are built with that same length so
     * the length assertion agrees with how they were configured.
     */
    private void testGetNextExpPipe(int noOfParties, BigInteger modulus, int expPipeLength) {
        List<SpdzDummyDataSupplier> suppliers =
                setupSuppliers(noOfParties, modulus, expPipeLength);
        BigInteger macKey = getMacKeyFromSuppliers(suppliers);
        List<List<SpdzSInt>> pipes = new ArrayList<>(noOfParties);
        for (SpdzDummyDataSupplier supplier : suppliers) {
            SpdzSInt[] pipe = supplier.getNextExpPipe();
            // One extra slot: the pipe holds r^-1 followed by r^1 .. r^expPipeLength.
            Assert.assertEquals(expPipeLength + 1, pipe.length);
            pipes.add(Arrays.asList(pipe));
        }
        assertExpPipeValid(recombineExpPipe(pipes), macKey, modulus);
    }

    private void testGetNextExpPipe(int noOfParties) {
        for (BigInteger modulus : moduli) {
            testGetNextExpPipe(noOfParties, modulus, DEFAULT_EXP_PIPE_LENGTH);
        }
    }

    @Test
    public void testGetNextTriple() {
        for (int noOfParties : PARTY_COUNTS) {
            testGetNextTriple(noOfParties);
        }
    }

    @Test
    public void testGetNextExpPipe() {
        for (int noOfParties : PARTY_COUNTS) {
            testGetNextExpPipe(noOfParties);
        }
    }

    @Test
    public void testGetNextInputMask() {
        for (int noOfParties : PARTY_COUNTS) {
            for (int towardsPlayerId = 1; towardsPlayerId <= noOfParties; towardsPlayerId++) {
                testGetNextInputMask(noOfParties, towardsPlayerId);
            }
        }
    }

    @Test
    public void testGetNextBit() {
        for (int noOfParties : PARTY_COUNTS) {
            testGetNextBit(noOfParties);
        }
    }

    @Test
    public void testGetNextRandomFieldElement() {
        for (int noOfParties : PARTY_COUNTS) {
            testGetNextRandomFieldElement(noOfParties);
        }
    }

    @Test
    public void testGetters() {
        BigInteger modulus = moduli.get(0);
        SpdzDummyDataSupplier supplier =
                new SpdzDummyDataSupplier(1, 2, modulus, BigInteger.ONE);
        Assert.assertEquals(modulus, supplier.getModulus());
        Assert.assertEquals(BigInteger.ONE, supplier.getSecretSharedKey());
    }

    /**
     * Sums the parties' shares of one value.
     *
     * @throws IllegalArgumentException if {@code shares} is empty
     */
    private SpdzSInt recombine(List<SpdzSInt> shares) {
        return shares.stream()
                .reduce(SpdzSInt::add)
                .orElseThrow(() -> new IllegalArgumentException("No shares to recombine"));
    }

    /**
     * Recombines exp pipes position by position.
     *
     * @param pipes one list per party, each holding that party's shares of the whole pipe
     * @return the recombined pipe
     */
    private List<SpdzSInt> recombineExpPipe(List<List<SpdzSInt>> pipes) {
        return TransposeUtils.transpose(pipes).stream()
                .map(this::recombine)
                .collect(Collectors.toList());
    }

    /** Recombines each component (a, b, c) of the parties' triple shares separately. */
    private SpdzTriple recombineTriples(List<SpdzTriple> triples) {
        List<SpdzSInt> aShares = new ArrayList<>(triples.size());
        List<SpdzSInt> bShares = new ArrayList<>(triples.size());
        List<SpdzSInt> cShares = new ArrayList<>(triples.size());
        for (SpdzTriple triple : triples) {
            aShares.add(triple.getA());
            bShares.add(triple.getB());
            cShares.add(triple.getC());
        }
        return new SpdzTriple(recombine(aShares), recombine(bShares), recombine(cShares));
    }

    /** Asserts that {@code mac == value * macKey mod modulus} for a recombined value. */
    private void assertMacCorrect(SpdzSInt value, BigInteger macKey, BigInteger modulus) {
        Assert.assertEquals(value.getShare().multiply(macKey).mod(modulus), value.getMac());
    }

    /**
     * Asserts that a recombined triple has valid MACs on all components and that
     * {@code c == a * b mod modulus}.
     */
    private void assertTripleValid(SpdzTriple triple, BigInteger macKey, BigInteger modulus) {
        assertMacCorrect(triple.getA(), macKey, modulus);
        assertMacCorrect(triple.getB(), macKey, modulus);
        assertMacCorrect(triple.getC(), macKey, modulus);
        BigInteger expectedC =
                triple.getA().getShare().multiply(triple.getB().getShare()).mod(modulus);
        Assert.assertEquals(expectedC, triple.getC().getShare());
    }

    /**
     * Asserts that a recombined exp pipe has the shape {@code [r^-1, r^1, r^2, ...]} and that
     * every element carries a valid MAC.
     */
    private void assertExpPipeValid(List<SpdzSInt> pipe, BigInteger macKey, BigInteger modulus) {
        Assert.assertTrue("Exp pipe too short: " + pipe.size(), pipe.size() >= 2);
        List<BigInteger> values = new ArrayList<>(pipe.size());
        for (SpdzSInt element : pipe) {
            assertMacCorrect(element, macKey, modulus);
            values.add(element.getShare());
        }
        BigInteger r = values.get(1);
        Assert.assertEquals(r.modInverse(modulus), values.get(0));
        // Walk the powers incrementally instead of recomputing each modPow from scratch.
        BigInteger expectedPower = BigInteger.ONE;
        for (int i = 1; i < values.size(); i++) {
            expectedPower = expectedPower.multiply(r).mod(modulus);
            Assert.assertEquals(expectedPower, values.get(i));
        }
    }
}
