package cc.mallet.grmm.test;

import cc.mallet.grmm.inference.RandomGraphs;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.ConstantFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.ModelReader;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import cc.mallet.util.Timing;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import org.junit.Ignore;

/* loaded from: input_file:cc/mallet/grmm/test/TestFactorGraph.class */
/**
 * JUnit 3 tests for {@link FactorGraph}.
 *
 * <p>Covers building graphs with {@code multiplyBy}/{@code addFactor}, evaluating them
 * ({@code value}, {@code sum}, {@code marginalize}, {@code normalize}, {@code asTable}),
 * multiplying with table factors, removing factors, clearing, and sampling from models
 * parsed by {@link ModelReader}.
 */
public class TestFactorGraph extends TestCase {

    /** Seed shared by the sampling tests so their draws are reproducible. */
    private static final int SAMPLE_SEED = 324143;

    /** Number of joint samples drawn by each sampling test. */
    private static final int NUM_SAMPLES = 10000;

    /**
     * Two binary variables coupled by a BinaryPair factor, each with a Unary factor; the
     * continuous parameters sigma, u1, u2 all have Uniform(-0.5, 0.5) priors.
     */
    private static final String UNIFORM_PRIOR_MODEL =
            "VAR sigma u1 u2 : continuous\n"
            + "VAR x1 x2 : 2\n"
            + "sigma ~ Uniform -0.5 0.5\n"
            + "u1 ~ Uniform -0.5 0.5\n"
            + "u2 ~ Uniform -0.5 0.5\n"
            + "x1 x2 ~ BinaryPair sigma\n"
            + "x1 ~ Unary u1\n"
            + "x2 ~ Unary u2\n";

    /** Same structure as {@link #UNIFORM_PRIOR_MODEL}, but with Normal(0.0, 0.2) priors. */
    private static final String NORMAL_PRIOR_MODEL =
            "VAR sigma u1 u2 : continuous\n"
            + "VAR x1 x2 : 2\n"
            + "sigma ~ Normal 0.0 0.2\n"
            + "u1 ~ Normal 0.0 0.2\n"
            + "u2 ~ Normal 0.0 0.2\n"
            + "x1 x2 ~ BinaryPair sigma\n"
            + "x1 ~ Unary u1\n"
            + "x2 ~ Unary u2\n";

    private Variable[] vars;
    private TableFactor tbl1;
    private TableFactor tbl2;
    private TableFactor tbl3;
    private LogTableFactor ltbl1;
    private LogTableFactor ltbl2;

    public TestFactorGraph(String str) {
        super(str);
    }

    /**
     * Creates four binary variables v0..v3 and a chain of pairwise tables
     * tbl1(v0,v1), tbl2(v1,v2), tbl3(v2,v3), plus log-space factors ltbl1 and ltbl2
     * built from the same values as tbl1 and tbl2.
     */
    @Override // junit.framework.TestCase
    protected void setUp() throws Exception {
        this.vars = new Variable[]{new Variable(2), new Variable(2), new Variable(2), new Variable(2)};
        this.tbl1 = new TableFactor(new Variable[]{this.vars[0], this.vars[1]}, new double[]{0.8d, 0.1d, 0.1d, 0.8d});
        this.tbl2 = new TableFactor(new Variable[]{this.vars[1], this.vars[2]}, new double[]{0.2d, 0.7d, 0.8d, 0.2d});
        this.tbl3 = new TableFactor(new Variable[]{this.vars[2], this.vars[3]}, new double[]{0.2d, 0.4d, 0.6d, 0.4d});
        this.ltbl1 = LogTableFactor.makeFromValues(new Variable[]{this.vars[0], this.vars[1]}, new double[]{0.8d, 0.1d, 0.1d, 0.8d});
        this.ltbl2 = LogTableFactor.makeFromValues(new Variable[]{this.vars[1], this.vars[2]}, new double[]{0.2d, 0.7d, 0.8d, 0.2d});
    }

    /** Returns a fresh factor graph containing exactly tbl1 and tbl2 (variables v0, v1, v2). */
    private FactorGraph graphOfTbl1AndTbl2() {
        FactorGraph fg = new FactorGraph();
        fg.multiplyBy(this.tbl1);
        fg.multiplyBy(this.tbl2);
        return fg;
    }

    /** Parses a model specification string with {@link ModelReader}. */
    private static FactorGraph readModel(String spec) throws IOException {
        return new ModelReader().readModel(new BufferedReader(new StringReader(spec)));
    }

    /** Draws {@link #NUM_SAMPLES} joint samples from {@code model}, seeded with {@link #SAMPLE_SEED}. */
    private static Assignment drawSamples(FactorGraph model) {
        Randoms random = new Randoms(SAMPLE_SEED);
        Assignment samples = new Assignment();
        for (int i = 0; i < NUM_SAMPLES; i++) {
            samples.addRow(model.sample(random));
        }
        return samples;
    }

    /**
     * Returns the mean of the sampled integer values of {@code var} across all rows of
     * {@code samples}.
     */
    private static double empiricalMean(Assignment samples, Variable var) {
        int[] column = ((Assignment) samples.marginalize(var)).getColumnInt(var);
        // Cast before dividing: MatrixOps.sum over an int[] may yield an int, and int/int
        // would truncate the mean to 0 or 1.
        return (double) MatrixOps.sum(column) / column.length;
    }

    public void testMultiplyBy() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        assertEquals(2, fg.factors().size());
        assertTrue(fg.factors().contains(this.tbl1));
        assertTrue(fg.factors().contains(this.tbl2));
        assertEquals(3, fg.numVariables());
        assertTrue(fg.variablesSet().contains(this.vars[0]));
        assertTrue(fg.variablesSet().contains(this.vars[1]));
        assertTrue(fg.variablesSet().contains(this.vars[2]));
    }

    public void testNumVariables() {
        assertEquals(3, graphOfTbl1AndTbl2().numVariables());
    }

    /** {@code multiply} must return a new graph and leave the receiver unchanged. */
    public void testMultiply() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        FactorGraph product = (FactorGraph) fg.multiply(this.tbl3);
        assertEquals(2, fg.factors().size());
        assertEquals(3, product.factors().size());
        assertFalse(fg.factors().contains(this.tbl3));
        assertTrue(product.factors().contains(this.tbl3));
    }

    public void testValue() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        assertEquals(0.08d, fg.value(new Assignment(fg.varSet().toVariableArray(), new int[]{0, 1, 0})), 1.0E-5d);
    }

    public void testMarginalize() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        // v1=0: (0.8+0.1)*(0.2+0.7) = 0.81; v1=1: (0.1+0.8)*(0.8+0.2) = 0.9
        assertTrue(new TableFactor(this.vars[1], new double[]{0.81d, 0.9d}).almostEquals(fg.marginalize(this.vars[1])));
    }

    public void testSum() {
        // Total mass is the sum of the v1 marginal: 0.81 + 0.9.
        assertEquals(1.71d, graphOfTbl1AndTbl2().sum(), 1.0E-5d);
    }

    public void testNormalize() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        fg.normalize();
        assertEquals(1.0d, fg.sum(), 1.0E-5d);
    }

    public void testLogNormalize() {
        FactorGraph fg = new FactorGraph();
        fg.multiplyBy(this.ltbl1);
        fg.multiplyBy(this.ltbl2);
        fg.normalize();
        assertEquals(1.0d, fg.sum(), 1.0E-5d);
    }

    /** A factor graph multiplied into another one is held as a single factor. */
    public void testEmbeddedFactorGraph() {
        FactorGraph inner = graphOfTbl1AndTbl2();
        FactorGraph outer = new FactorGraph();
        outer.multiplyBy(inner);
        outer.multiplyBy(this.tbl3);
        assertEquals(4, outer.varSet().size());
        assertEquals(2, outer.factors().size());
        // All-zeros assignment: 0.8 * 0.2 * 0.2
        Assignment allZeros = new Assignment(outer.varSet().toVariableArray(), new int[4]);
        assertEquals(0.032d, outer.value(allZeros), 1.0E-5d);
        AbstractTableFactor table = outer.asTable();
        assertEquals(4, table.varSet().size());
        assertEquals(0.032d, table.value(allZeros), 1.0E-5d);
    }

    public void testAsTable() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        assertTrue(((AbstractTableFactor) this.tbl1.multiply(this.tbl2)).almostEquals(fg.asTable()));
    }

    public void testTableTimesFg() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        Factor product = this.tbl3.multiply(fg);
        assertTrue(product instanceof AbstractTableFactor);
        assertEquals(4, product.varSet().size());
        assertEquals(0.032d, product.value(new Assignment(product.varSet().toVariableArray(), new int[4])), 1.0E-5d);
    }

    public void testLogTableTimesFg() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        Factor product = this.ltbl1.multiply(fg);
        assertTrue(product instanceof AbstractTableFactor);
        assertEquals(3, product.varSet().size());
        // All-zeros assignment: 0.8 (ltbl1) * 0.8 (tbl1) * 0.2 (tbl2)
        assertEquals(0.128d, product.value(new Assignment(product.varSet().toVariableArray(), new int[3])), 1.0E-5d);
    }

    /** Dividing out tbl1 removes it and v0, leaving only tbl2 over (v1, v2). */
    public void testRemove() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        assertEquals(2, fg.getDegree(this.vars[1]));
        fg.divideBy(this.tbl1);
        assertEquals(2, fg.varSet().size());
        assertEquals(0.2d, fg.value(new Assignment(fg.varSet().toVariableArray(), new int[2])), 1.0E-5d);
        int numVarSets = 0;
        for (Iterator<?> it = fg.varSetIterator(); it.hasNext(); it.next()) {
            numVarSets++;
        }
        assertEquals(1, numVarSets);
        assertEquals(1, fg.getDegree(this.vars[1]));
        assertNotSame(fg.get(0), fg.get(1));
        assertEquals(this.vars[1], fg.get(0));
        assertEquals(this.vars[2], fg.get(1));
    }

    /** A second factor over an existing domain is kept as a separate factor. */
    public void testRedundantDomains() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        fg.multiplyBy(this.ltbl1);
        assertEquals(3, fg.varSet().size());
        assertEquals("Wrong factors in FG, was " + fg.dumpToString(), 3, fg.factors().size());
        assertEquals(0.128d, fg.value(new Assignment(fg.varSet().toVariableArray(), new int[3])), 1.0E-5d);
    }

    // NOTE(review): org.junit.Ignore is a JUnit 4 annotation and is not honored for
    // junit.framework.TestCase subclasses run via TestSuite, so this method still runs.
    @Ignore
    public void testContinousSample() throws IOException {
        FactorGraph model = readModel(UNIFORM_PRIOR_MODEL);
        Assignment samples = drawSamples(model);
        assertEquals(0.5d, empiricalMean(samples, model.findVariable("x1")), 0.025d);
    }

    // NOTE(review): @Ignore is not honored under the JUnit 3 runner; see testContinousSample.
    @Ignore
    public void testContinousSample2() throws IOException {
        FactorGraph model = readModel(NORMAL_PRIOR_MODEL);
        Assignment samples = drawSamples(model);
        // Check each of the two binary variables once.
        assertEquals(0.5d, empiricalMean(samples, model.findVariable("x1")), 0.01d);
        assertEquals(0.5d, empiricalMean(samples, model.findVariable("x2")), 0.025d);
    }

    // NOTE(review): @Ignore is not honored under the JUnit 3 runner; see testContinousSample.
    @Ignore
    public void testAllFactorsOf() throws IOException {
        FactorGraph model = readModel(NORMAL_PRIOR_MODEL);
        Variable stranger = new Variable(2);
        stranger.setLabel("v0");
        assertEquals(0, model.allFactorsOf(stranger).size());
    }

    /** {@code allFactorsOf} returns only factors whose domain is exactly the given variable(s). */
    public void testAllFactorsOf2() throws IOException {
        Variable a = new Variable(2);
        Variable b = new Variable(2);
        FactorGraph fg = new FactorGraph();
        fg.addFactor(new TableFactor(a));
        fg.addFactor(new TableFactor(b));
        fg.addFactor(new TableFactor(new Variable[]{a, b}));

        List<Factor> unaryOfA = fg.allFactorsOf(a);
        assertEquals(1, unaryOfA.size());
        for (Factor factor : unaryOfA) {
            assertEquals(1, factor.varSet().size());
            assertTrue(factor.varSet().contains(a));
        }

        HashVarSet pair = new HashVarSet(new Variable[]{a, b});
        List<?> pairFactors = fg.allFactorsOf(pair);
        assertEquals(1, pairFactors.size());
        assertTrue(((Factor) pairFactors.get(0)).varSet().equals(pair));
    }

    public void testAsTable2() {
        FactorGraph fg = new FactorGraph(new Factor[]{
                new TableFactor(this.vars[0], new double[]{0.6d, 0.4d}),
                new ConstantFactor(2.0d)});
        assertTrue(Arrays.equals(new double[]{1.2d, 0.8d}, fg.asTable().toValueArray()));
    }

    public void testClear() {
        FactorGraph fg = graphOfTbl1AndTbl2();
        assertEquals(3, fg.numVariables());
        assertEquals(2, fg.factors().size());
        fg.clear();
        assertEquals(0, fg.numVariables());
        assertEquals(0, fg.factors().size());
        for (int i = 0; i < this.tbl1.varSet().size(); i++) {
            assertFalse(fg.containsVar(this.tbl1.getVariable(i)));
        }
        for (int i = 0; i < this.tbl2.varSet().size(); i++) {
            assertFalse(fg.containsVar(this.tbl2.getVariable(i)));
        }
    }

    /**
     * Builds the same random grid 100 times into a FactorGraph constructed with the variable
     * count ("No-expansion"), then 100 times into a default-constructed one ("With-expansion").
     * Each copy must reproduce the grid's log-value for the all-zeros assignment. Both timings
     * are printed via {@link Timing#tick}.
     */
    public void testCacheExpanding() {
        UndirectedGrid grid = RandomGraphs.randomFrustratedGrid(25, 1.0d, new Random(3324879L));
        Assignment allZeros = new Assignment(grid, new int[grid.numVariables()]);
        double expected = grid.logValue(allZeros);
        int numFactors = grid.factors().size();

        Timing timing = new Timing();
        for (int rep = 0; rep < 100; rep++) {
            FactorGraph presized = new FactorGraph(grid.numVariables());
            for (int f = 0; f < numFactors; f++) {
                presized.multiplyBy(grid.getFactor(f));
            }
            assertEquals(expected, presized.logValue(allZeros), 1.0E-5d);
        }
        timing.tick("No-expansion time");

        for (int rep = 0; rep < 100; rep++) {
            FactorGraph growing = new FactorGraph();
            for (int f = 0; f < numFactors; f++) {
                growing.multiplyBy(grid.getFactor(f));
            }
            assertEquals(expected, growing.logValue(allZeros), 1.0E-5d);
        }
        timing.tick("With-expansion time");
        // No assertion compares the two wall-clock timings. The first loop also pays JIT
        // warm-up, and wall-clock ordering is nondeterministic, so such an assert is flaky.
    }

    public static Test suite() {
        return new TestSuite(TestFactorGraph.class);
    }

    /** Runs the named test methods, or the whole suite when no names are given. */
    public static void main(String[] strArr) throws Throwable {
        TestSuite testSuite;
        if (strArr.length > 0) {
            testSuite = new TestSuite();
            for (String name : strArr) {
                testSuite.addTest(new TestFactorGraph(name));
            }
        } else {
            testSuite = (TestSuite) suite();
        }
        TestRunner.run(testSuite);
    }
}
