package cc.mallet.grmm.test;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.types.tests.TestSerializable;
import cc.mallet.util.ArrayUtils;
import cc.mallet.util.Maths;
import cc.mallet.util.Randoms;
import java.io.IOException;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:cc/mallet/grmm/test/TestLogTableFactor.class */
/**
 * Unit tests for {@link LogTableFactor}: arithmetic against both log-space and ordinary
 * {@link TableFactor}s, entropy, max-extraction, value / log-value accessors, slicing,
 * sparse storage, sampling, scalar addition and recentering.
 */
public class TestLogTableFactor extends TestCase {

    /** Absolute tolerance used when comparing individual factor values. */
    private static final double TOLERANCE = 1.0E-5d;

    public TestLogTableFactor(String name) {
        super(name);
    }

    /** Multiplying a TableFactor by a LogTableFactor yields the elementwise product. */
    public void testTimesTableFactor() {
        Variable var = new Variable(4);
        TableFactor expected = new TableFactor(var, new double[]{1.0d, 2.0d, 3.0d, 4.0d});
        TableFactor actual = new TableFactor(var, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        actual.multiplyBy(LogTableFactor.makeFromValues(var, new double[]{0.5d, 0.5d, 0.5d, 0.5d}));
        assertTrue(expected.almostEquals(actual));
    }

    /** plusEquals between two LogTableFactors adds their (non-log) values elementwise. */
    public void testTblTblPlusEquals() {
        Variable var = new Variable(4);
        LogTableFactor expected = LogTableFactor.makeFromValues(var, new double[]{2.25d, 4.5d, 6.75d, 9.0d});
        LogTableFactor actual = LogTableFactor.makeFromValues(var, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        actual.plusEquals(LogTableFactor.makeFromValues(var, new double[]{0.25d, 0.5d, 0.75d, 1.0d}));
        assertTrue(expected.almostEquals(actual));
    }

    /** multiplyBy gives the same product for every combination of table / log-table operands. */
    public void testMultiplyByLogSpace() {
        Variable var = new Variable(4);
        double[] lhsValues = {2.0d, 4.0d, 6.0d, 8.0d};
        double[] halves = {0.5d, 0.5d, 0.5d, 0.5d};
        TableFactor expected = new TableFactor(var, new double[]{1.0d, 2.0d, 3.0d, 4.0d});

        // table * table
        TableFactor tblTimesTbl = new TableFactor(var, lhsValues);
        tblTimesTbl.multiplyBy(new TableFactor(var, halves));
        assertTrue(expected.almostEquals(tblTimesTbl));

        // table * log-table
        TableFactor tblTimesLog = new TableFactor(var, lhsValues);
        tblTimesLog.multiplyBy(LogTableFactor.makeFromValues(var, halves));
        assertTrue(tblTimesLog.almostEquals(expected));

        // log-table * table
        TableFactor tblOperand = new TableFactor(var, lhsValues);
        LogTableFactor logTimesTbl = LogTableFactor.makeFromValues(var, halves);
        logTimesTbl.multiplyBy(tblOperand);
        assertTrue(logTimesTbl.almostEquals(expected));

        // log-table * log-table
        LogTableFactor logOperand = LogTableFactor.makeFromValues(var, lhsValues);
        LogTableFactor logTimesLog = LogTableFactor.makeFromValues(var, halves);
        logTimesLog.multiplyBy(logOperand);
        assertTrue(logTimesLog.almostEquals(expected));
    }

    /** divideBy gives the same quotient for every combination of table / log-table operands. */
    public void testDivideByLogSpace() {
        Variable var = new Variable(4);
        double[] numerators = {2.0d, 4.0d, 6.0d, 8.0d};
        double[] halves = {0.5d, 0.5d, 0.5d, 0.5d};
        TableFactor expected = new TableFactor(var, new double[]{4.0d, 8.0d, 12.0d, 16.0d});

        // table / table
        TableFactor tblOverTbl = new TableFactor(var, numerators);
        tblOverTbl.divideBy(new TableFactor(var, halves));
        assertTrue(expected.almostEquals(tblOverTbl));

        // table / log-table
        TableFactor tblOverLog = new TableFactor(var, numerators);
        tblOverLog.divideBy(LogTableFactor.makeFromValues(var, halves));
        assertTrue(tblOverLog.almostEquals(expected));

        // log-table / table
        LogTableFactor logOverTbl = LogTableFactor.makeFromValues(var, numerators);
        logOverTbl.divideBy(new TableFactor(var, halves));
        assertTrue(logOverTbl.almostEquals(expected));

        // log-table / log-table
        LogTableFactor logOverLog = LogTableFactor.makeFromValues(var, numerators);
        logOverLog.divideBy(LogTableFactor.makeFromValues(var, halves));
        assertTrue(logOverLog.almostEquals(expected));
    }

    /** Entropy of the distribution (0.3, 0.7) matches for TableFactor and LogTableFactor. */
    public void testEntropyLogSpace() {
        Variable var = new Variable(2);
        assertEquals(0.61086d, new TableFactor(var, new double[]{0.3d, 0.7d}).entropy(), 0.001d);
        assertEquals(0.61086d, LogTableFactor.makeFromValues(var, new double[]{0.3d, 0.7d}).entropy(), 0.001d);
    }

    /**
     * Round-trips a LogTableFactor through Java serialization and checks its values and a
     * marginal survive. Disabled: the name does not start with "test", so JUnit 3 skips it.
     */
    public void ignoreTestSerialization() throws IOException, ClassNotFoundException {
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(3);
        LogTableFactor original = LogTableFactor.makeFromLogValues(
                new Variable[]{v1, v2}, new double[]{2.0d, 4.0d, 6.0d, 3.0d, 5.0d, 7.0d});
        LogTableFactor copy = (LogTableFactor) TestSerializable.cloneViaSerialization(original);
        // NOTE(review): this asks whether the copy's VarSet is an *element* of the original's
        // VarSet (contains is also used with single Variables elsewhere in this class), which
        // is probably always false -- confirm the intended check (e.g. variable identity).
        assertTrue(!original.varSet().contains(copy.varSet()));
        comparePotentialValues(original, copy);
        comparePotentialValues(
                (LogTableFactor) original.marginalize(v1),
                (LogTableFactor) copy.marginalize(copy.findVariable(v1.getLabel())));
    }

    /**
     * Asserts that two factors hold the same values, walking both assignment iterators in
     * lockstep, and that they enumerate the same number of assignments.
     *
     * @param left  first factor (usually the computed result)
     * @param right second factor (usually the expected values)
     */
    private void comparePotentialValues(LogTableFactor left, LogTableFactor right) {
        AssignmentIterator leftIt = left.assignmentIterator();
        AssignmentIterator rightIt = right.assignmentIterator();
        while (leftIt.hasNext()) {
            assertTrue("second factor has fewer assignments than the first", rightIt.hasNext());
            // Each factor must be read through its own iterator; comparing one factor with
            // itself would make this helper vacuous.
            assertEquals(left.value(leftIt), right.value(rightIt), TOLERANCE);
            leftIt.advance();
            rightIt.advance();
        }
        assertTrue("second factor has more assignments than the first", !rightIt.hasNext());
    }

    /** extractMax over one variable keeps only that variable and its per-value maxima. */
    public void testExtractMaxLogSpace() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        LogTableFactor maxed = (LogTableFactor) LogTableFactor
                .makeFromValues(vars, new double[]{1.0d, 2.0d, 3.0d, 4.0d})
                .extractMax(vars[1]);
        assertEquals("FAILURE: Potential has too many vars.\n  " + maxed, 1, maxed.varSet().size());
        assertTrue("FAILURE: Potential does not contain " + vars[1] + ":\n  " + maxed,
                maxed.varSet().contains(vars[1]));
        double[] expected = {3.0d, 4.0d};
        assertTrue("FAILURE: Potential has incorrect values.  Expected " + ArrayUtils.toString(expected)
                        + " was " + maxed,
                Maths.almostEquals(maxed.toValueArray(), expected, TOLERANCE));
    }

    /** All value / log-value accessors agree on the first entry, however the factor was built. */
    public void testLogValue() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        LogTableFactor fromValues = LogTableFactor.makeFromValues(vars, new double[]{1.0d, 2.0d, 3.0d, 4.0d});
        assertFirstEntryIsOne(fromValues, vars);
        LogTableFactor fromLogValues = LogTableFactor.makeFromLogValues(
                vars, new double[]{0.0d, Math.log(2.0d), Math.log(3.0d), Math.log(4.0d)});
        assertFirstEntryIsOne(fromLogValues, vars);
    }

    /**
     * Asserts that the all-zeros assignment of {@code factor} has value 1 (log value 0) through
     * the Assignment, AssignmentIterator and integer-index accessors.
     */
    private void assertFirstEntryIsOne(LogTableFactor factor, Variable[] vars) {
        Assignment zeros = new Assignment(vars, new int[vars.length]);
        assertEquals(0.0d, factor.logValue(zeros), TOLERANCE);
        assertEquals(0.0d, factor.logValue(factor.assignmentIterator()), TOLERANCE);
        assertEquals(0.0d, factor.logValue(0), TOLERANCE);
        assertEquals(1.0d, factor.value(zeros), TOLERANCE);
        assertEquals(1.0d, factor.value(factor.assignmentIterator()), TOLERANCE);
        assertEquals(1.0d, factor.value(0), TOLERANCE);
    }

    /** Slicing a two-variable factor on one variable leaves the matching row of the other. */
    public void testOneVarSlice() {
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        // Log values are ln 1, ln 4, ln 2, ln 6.
        LogTableFactor full = LogTableFactor.makeFromLogValues(new Variable[]{v1, v2},
                new double[]{0.0d, 1.3862943611198906d, 0.6931471805599453d, 1.791759469228055d});
        Assignment assn = new Assignment(v1, 0);
        comparePotentialValues((LogTableFactor) full.slice(assn),
                LogTableFactor.makeFromValues(v2, new double[]{1.0d, 4.0d}));
        // Slicing must not modify the assignment it was given.
        assertEquals(1, assn.varSet().size());
    }

    /** Slicing a three-variable factor on its last variable keeps the even-indexed entries. */
    public void testTwoVarSlice() {
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable v3 = new Variable(2);
        LogTableFactor full = LogTableFactor.makeFromValues(new Variable[]{v1, v2, v3},
                new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d});
        LogTableFactor sliced = (LogTableFactor) full.slice(new Assignment(v3, 0));
        LogTableFactor expected = LogTableFactor.makeFromValues(new Variable[]{v1, v2},
                new double[]{0.0d, 2.0d, 4.0d, 6.0d});
        comparePotentialValues(sliced, expected);
    }

    /** Slicing a four-variable factor on its last variable keeps the even-indexed entries. */
    public void testMultiVarSlice() {
        Variable v1 = new Variable(2);
        Variable v2 = new Variable(2);
        Variable v3 = new Variable(2);
        Variable v4 = new Variable(2);
        LogTableFactor full = LogTableFactor.makeFromValues(new Variable[]{v1, v2, v3, v4},
                new double[]{0.0d, 1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d,
                             8.0d, 9.0d, 10.0d, 11.0d, 12.0d, 13.0d, 14.0d, 15.0d});
        LogTableFactor sliced = (LogTableFactor) full.slice(new Assignment(v4, 0));
        LogTableFactor expected = LogTableFactor.makeFromValues(new Variable[]{v1, v2, v3},
                new double[]{0.0d, 2.0d, 4.0d, 6.0d, 8.0d, 10.0d, 12.0d, 14.0d});
        comparePotentialValues(sliced, expected);
    }

    /**
     * A sparse factor's own iterator starts at its first stored entry, while the var-set
     * iterator also visits the implicit zeros (log value -infinity).
     */
    public void testSparseValueAndLogValue() {
        LogTableFactor sparse = LogTableFactor.makeFromMatrix(
                new Variable[]{new Variable(2), new Variable(2)},
                new SparseMatrixn(new int[]{2, 2}, new int[]{1, 3}, new double[]{4.0d, 8.0d}));

        AssignmentIterator storedIt = sparse.assignmentIterator();
        assertEquals(1, storedIt.indexOfCurrentAssn());
        assertEquals(Math.log(4.0d), sparse.logValue(storedIt), TOLERANCE);
        assertEquals(Math.log(4.0d), sparse.logValue(storedIt.assignment()), TOLERANCE);
        assertEquals(4.0d, sparse.value(storedIt), TOLERANCE);
        assertEquals(4.0d, sparse.value(storedIt.assignment()), TOLERANCE);

        AssignmentIterator allIt = sparse.varSet().assignmentIterator();
        assertEquals(0, allIt.indexOfCurrentAssn());
        assertEquals(Double.NEGATIVE_INFINITY, sparse.logValue(allIt), TOLERANCE);
        assertEquals(Double.NEGATIVE_INFINITY, sparse.logValue(allIt.assignment()), TOLERANCE);
        assertEquals(0.0d, sparse.value(allIt), TOLERANCE);
        assertEquals(0.0d, sparse.value(allIt.assignment()), TOLERANCE);
    }

    /** Multiplying two sparse factors zeroes entries missing from either operand. */
    public void testSparseMultiplyLogSpace() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        int[] sizes = {2, 2};
        int[] indices = {0, 1, 3};
        LogTableFactor lhs = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, indices, new double[]{2.0d, 4.0d, 8.0d}));
        LogTableFactor rhs = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, new int[]{0, 3}, new double[]{0.5d, 0.5d}));
        LogTableFactor expected = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, indices, new double[]{1.0d, 0.0d, 4.0d}));
        Factor product = lhs.multiply(rhs);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + product,
                expected.almostEquals(product));
    }

    /** Dividing sparse factors yields 0 where the divisor has no stored entry. */
    public void testSparseDivideLogSpace() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        int[] sizes = {2, 2};
        int[] indices = {0, 1, 3};
        LogTableFactor quotient = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, indices, new double[]{2.0d, 4.0d, 8.0d}));
        LogTableFactor divisor = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, new int[]{0, 3}, new double[]{0.5d, 0.5d}));
        LogTableFactor expected = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(sizes, indices, new double[]{4.0d, 0.0d, 16.0d}));
        quotient.divideBy(divisor);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + quotient,
                expected.almostEquals(quotient));
    }

    /** Marginalizing a sparse factor sums over the stored entries only. */
    public void testSparseMarginalizeLogSpace() {
        Variable[] vars = {new Variable(2), new Variable(2)};
        LogTableFactor sparse = LogTableFactor.makeFromMatrix(vars,
                new SparseMatrixn(new int[]{2, 2}, new int[]{0, 1, 3}, new double[]{2.0d, 4.0d, 8.0d}));
        LogTableFactor expected = LogTableFactor.makeFromValues(vars[0], new double[]{6.0d, 8.0d});
        Factor marginal = sparse.marginalize(vars[0]);
        assertTrue("Test failed! Expected: " + expected + " Actual: " + marginal + " Orig: " + sparse,
                expected.almostEquals(marginal));
    }

    /** Sampling from log values (-30, 0) picks the overwhelmingly likely second entry. */
    public void testLogSample() {
        LogTableFactor factor = LogTableFactor.makeFromLogValues(new Variable(2), new double[]{-30.0d, 0.0d});
        assertEquals(1, factor.sampleLocation(new Randoms(43)));
    }

    /** Adding 0.1 to every (non-log) value, including one that starts at zero. */
    public void testPlusEquals() {
        Variable var = new Variable(4);
        // Log values of 0, 1, 2, 3.
        LogTableFactor actual = LogTableFactor.makeFromLogValues(var,
                new double[]{Double.NEGATIVE_INFINITY, 0.0d, 0.6931471805599453d, 1.0986122886681098d});
        actual.plusEquals(0.1d);
        // Log values of 0.1, 1.1, 2.1, 3.1.
        LogTableFactor expected = LogTableFactor.makeFromLogValues(var,
                new double[]{-2.3025850929940455d, 0.09531017980432493d, 0.7419373447293773d, 1.1314021114911006d});
        assertTrue("Error: expected " + expected.dumpToString() + " but was " + actual.dumpToString(),
                actual.almostEquals(expected));
    }

    /** recenter() rescales so that the largest entry has log value 0. */
    public void testRecenter() {
        Variable var = new Variable(4);
        LogTableFactor actual = LogTableFactor.makeFromValues(var, new double[]{2.0d, 4.0d, 6.0d, 8.0d});
        actual.recenter();
        LogTableFactor expected = LogTableFactor.makeFromLogValues(var,
                new double[]{Math.log(0.25d), Math.log(0.5d), Math.log(0.75d), 0.0d});
        assertTrue("Error: expected " + expected.dumpToString() + " but was " + actual.dumpToString(),
                expected.almostEquals(actual));
    }

    /** recenter() with tied maxima stays NaN-free and shifts every log value by the max. */
    public void testRecenter2() {
        Variable var = new Variable(4);
        LogTableFactor actual = LogTableFactor.makeFromLogValues(var, new double[]{0.0d, 1.4d, 1.4d, 0.0d});
        actual.recenter();
        LogTableFactor expected = LogTableFactor.makeFromLogValues(var, new double[]{-1.4d, 0.0d, 0.0d, -1.4d});
        assertTrue(!actual.isNaN());
        assertTrue("Error: expected " + expected.dumpToString() + " but was " + actual.dumpToString(),
                expected.almostEquals(actual));
    }

    /** @return a suite containing every {@code test*} method of this class */
    public static Test suite() {
        return new TestSuite(TestLogTableFactor.class);
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args names of individual test methods to run
     */
    public static void main(String[] args) throws Throwable {
        TestSuite toRun;
        if (args.length == 0) {
            toRun = suite();
        } else {
            toRun = new TestSuite();
            for (String testName : args) {
                toRun.addTest(new TestLogTableFactor(testName));
            }
        }
        TestRunner.run(toRun);
    }
}
