package com.aliasi.test.unit.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.crf.ChainCrf;
import com.aliasi.crf.ChainCrfFeatureExtractor;
import com.aliasi.crf.ChainCrfFeatures;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.symbol.SymbolTableCompiler;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagging;
import com.aliasi.test.unit.Asserts;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Strings;
import com.aliasi.xml.XHtmlWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import junit.framework.Assert;
import org.junit.Test;

/* loaded from: input_file:com/aliasi/test/unit/crf/ChainCrfTest.class */
public class ChainCrfTest {
    // ---- Hand-built CRF fixture -------------------------------------------------------
    // The three output tags. A tag's position in TAGS is the integer tag index used by the
    // brute-force scorer and by the rows of the weight tables below.
    static String CAT1 = "X";
    static String CAT2 = "Y";
    static String CAT3 = "Z";
    static String[] TAGS = {CAT1, CAT2, CAT3};
    // The four input tokens. A token's position in TOKENS is the integer token index used
    // as the column of TOKEN_WEIGHTS.
    // NOTE(review): X1/X2 are taken from XHtmlWriter constants, presumably "a" and "b" to
    // continue the "c"/"d" sequence below -- confirm against XHtmlWriter.
    static String X1 = XHtmlWriter.A;
    static String X2 = XHtmlWriter.B;
    static String X3 = "c";
    static String X4 = "d";
    static String[] TOKENS = {X1, X2, X3, X4};
    // Feature names: the three tag names (emitted as edge features) followed by the four
    // tokens (emitted as node features). The order matches the dimensions of COEFFICIENTS.
    static String[] FEATURES = {CAT1, CAT2, CAT3, X1, X2, X3, X4};
    // Transition weights. A constant named PQ is the weight for previous tag P followed by
    // current tag Q.
    static double XX = 1.0d;
    static double XY = 1.0d;
    static double XZ = 2.0d;
    static double YX = 2.0d;
    static double YY = -1.0d;
    static double YZ = 4.0d;
    static double ZX = 3.0d;
    static double ZY = 1.0d;
    static double ZZ = 6.0d;
    // Indexed as TRANSITION_WEIGHTS[currentTag][previousTag], which is how score(...) reads it.
    static double[][] TRANSITION_WEIGHTS = {new double[]{XX, YX, ZX}, new double[]{XY, YY, ZY}, new double[]{XZ, YZ, ZZ}};
    // Emission weights. A constant named Tt is the weight for tag T on token t.
    static double Xa = 4.0d;
    static double Xb = 5.0d;
    static double Xc = 6.0d;
    static double Xd = 7.0d;
    static double Ya = -1.0d;
    static double Yb = 10.0d;
    static double Yc = -1.0d;
    static double Yd = 1.0d;
    static double Za = -2.0d;
    static double Zb = -4.0d;
    static double Zc = -6.0d;
    static double Zd = 15.0d;
    // Indexed as TOKEN_WEIGHTS[tag][token], which is how score(...) reads it.
    static double[][] TOKEN_WEIGHTS = {new double[]{Xa, Xb, Xc, Xd}, new double[]{Ya, Yb, Yc, Yd}, new double[]{Za, Zb, Zc, Zd}};
    static int NUM_TAGS = TAGS.length;
    // One coefficient vector per tag: that tag's TRANSITION_WEIGHTS row followed by its
    // TOKEN_WEIGHTS row, i.e. one value per entry of FEATURES.
    static Vector[] COEFFICIENTS = {new DenseVector(new double[]{XX, YX, ZX, Xa, Xb, Xc, Xd}), new DenseVector(new double[]{XY, YY, ZY, Ya, Yb, Yc, Yd}), new DenseVector(new double[]{XZ, YZ, ZZ, Za, Zb, Zc, Zd})};
    // NOTE(review): presumably assigns symbol ids in array order so that feature ids line up
    // with the coefficient dimensions above -- confirm against SymbolTableCompiler.asSymbolTable.
    static final SymbolTable FEATURE_SYMBOL_TABLE = SymbolTableCompiler.asSymbolTable(FEATURES);
    static final ChainCrfFeatureExtractor<String> FEATURE_EXTRACTOR = new TestFeatureExtractor();
    // The hand-built model is constructed without an intercept feature.
    static boolean ADD_INTERCEPT_FEATURE = false;
    // The CRF under test, assembled directly from the fixture above rather than estimated.
    static ChainCrf<String> CRF = new ChainCrf<>(TAGS, COEFFICIENTS, FEATURE_SYMBOL_TABLE, FEATURE_EXTRACTOR, ADD_INTERCEPT_FEATURE);

    /* loaded from: input_file:com/aliasi/test/unit/crf/ChainCrfTest$TestCorpus.class */
    static class TestCorpus extends Corpus<ObjectHandler<Tagging<String>>> {
        static final String[][][] WORDS_TAGSS = {new String[]{new String[0], new String[0]}, new String[]{new String[]{"."}, new String[]{"EOS"}}, new String[]{new String[]{"John", "ran", "."}, new String[]{"PN", "IV", "EOS"}}, new String[]{new String[]{"Mary", "ran", "."}, new String[]{"PN", "IV", "EOS"}}, new String[]{new String[]{"John", "jumped", "!"}, new String[]{"PN", "IV", "EOS"}}, new String[]{new String[]{"The", "dog", "jumped", "!"}, new String[]{"DET", "N", "IV", "EOS"}}, new String[]{new String[]{"The", "dog", "sat", "."}, new String[]{"DET", "N", "IV", "EOS"}}, new String[]{new String[]{"Mary", "sat", "!"}, new String[]{"PN", "IV", "EOS"}}, new String[]{new String[]{"Mary", "likes", "John", "."}, new String[]{"PN", "TV", "PN", "EOS"}}, new String[]{new String[]{"The", "dog", "likes", "Mary", "."}, new String[]{"DET", "N", "TV", "PN", "EOS"}}, new String[]{new String[]{"John", "likes", "the", "dog", "."}, new String[]{"PN", "TV", "DET", "N", "EOS"}}, new String[]{new String[]{"The", "dog", "ran", "."}, new String[]{"DET", "N", "IV", "EOS"}}, new String[]{new String[]{"The", "dog", "ran", "."}, new String[]{"DET", "N", "IV", "EOS"}}};

        TestCorpus() {
        }

        @Override // com.aliasi.corpus.Corpus
        public void visitTrain(ObjectHandler<Tagging<String>> objectHandler) {
            for (String[][] strArr : WORDS_TAGSS) {
                objectHandler.handle(new Tagging<>(Arrays.asList(strArr[0]), Arrays.asList(strArr[1])));
            }
        }

        @Override // com.aliasi.corpus.Corpus
        public void visitTest(ObjectHandler<Tagging<String>> objectHandler) {
        }
    }

    /* loaded from: input_file:com/aliasi/test/unit/crf/ChainCrfTest$TestCrfFeatures.class */
    /**
     * Minimal feature set for the tests: each node emits its own token as the only feature
     * and each edge emits a single tag name, both with value 1.
     */
    static class TestCrfFeatures extends ChainCrfFeatures<String> {
        public TestCrfFeatures(List<String> tokens, List<String> tags) {
            super(tokens, tags);
        }

        /** Returns a one-entry map from the token at {@code position} to 1. */
        @Override
        public Map<String, Integer> nodeFeatures(int position) {
            String tokenAtPosition = token(position);
            return Collections.singletonMap(tokenAtPosition, Integer.valueOf(1));
        }

        /**
         * Returns a one-entry map from the tag with index {@code tagIndex} to 1; the position
         * argument is ignored.
         */
        @Override
        public Map<String, Integer> edgeFeatures(int position, int tagIndex) {
            // NOTE(review): tagIndex is presumably the previous tag's index -- confirm against ChainCrfFeatures.
            String tagName = tag(tagIndex);
            return Collections.singletonMap(tagName, Integer.valueOf(1));
        }
    }

    /* loaded from: input_file:com/aliasi/test/unit/crf/ChainCrfTest$TestFeatureExtractor.class */
    /**
     * Feature extractor that wraps every input in a {@link TestCrfFeatures}. It is
     * {@link Serializable} because the CRF holding it is serialized in {@code testDecoder}.
     */
    static class TestFeatureExtractor implements ChainCrfFeatureExtractor<String>, Serializable {
        TestFeatureExtractor() {
        }

        /** Builds the test feature set over the given tokens and tags. */
        @Override
        public ChainCrfFeatures<String> extract(List<String> tokens, List<String> tags) {
            ChainCrfFeatures<String> features = new TestCrfFeatures(tokens, tags);
            return features;
        }
    }

    /**
     * Checks the hand-built CRF against brute-force enumeration.
     *
     * <p>First the CRF is round-tripped through serialization and the copy is compared with the
     * original: intercept flag, feature symbols, tags and coefficient vectors. Then, for every
     * token sequence of length 0 through 4 over {@code TOKENS}, the first-best output of both the
     * original and the deserialized CRF, and the n-best, conditional n-best and marginal outputs
     * of the original, are checked against scores computed by {@code bruteForce}.
     *
     * @throws IOException if the serialization round trip fails
     */
    @Test
    public void testDecoder() throws IOException {
        // The round trip hands back an untyped object; CRF is a ChainCrf<String>, so the copy is too.
        @SuppressWarnings("unchecked")
        ChainCrf<String> chainCrf = (ChainCrf<String>) AbstractExternalizable.serializeDeserialize(CRF);
        Assert.assertEquals(CRF.addInterceptFeature(), chainCrf.addInterceptFeature());
        Assert.assertEquals(CRF.featureSymbolTable().numSymbols(), chainCrf.featureSymbolTable().numSymbols());
        for (int i = 0; i < CRF.featureSymbolTable().numSymbols(); i++) {
            Assert.assertEquals(CRF.featureSymbolTable().idToSymbol(i), chainCrf.featureSymbolTable().idToSymbol(i));
        }
        Assert.assertEquals(CRF.tags(), chainCrf.tags());
        Vector[] coefficients = CRF.coefficients();
        Vector[] coefficients2 = chainCrf.coefficients();
        Assert.assertEquals(coefficients.length, coefficients2.length);
        for (int i2 = 0; i2 < coefficients.length; i2++) {
            Assert.assertEquals(coefficients[i2].numDimensions(), coefficients2[i2].numDimensions());
            org.junit.Assert.assertArrayEquals(coefficients[i2].nonZeroDimensions(), coefficients2[i2].nonZeroDimensions());
            for (int i3 : coefficients[i2].nonZeroDimensions()) {
                Assert.assertEquals(coefficients[i2].value(i3), coefficients2[i2].value(i3), 1.0E-4d);
            }
        }
        for (int i4 = 0; i4 < 5; i4++) {
            // Each int[] is a sequence of token indexes into TOKENS.
            for (int[] iArr : allArrays(i4, TOKENS.length)) {
                List<String> tokens = new ArrayList<>(i4);
                for (int i5 : iArr) {
                    tokens.add(TOKENS[i5]);
                }
                ObjectToDoubleMap<int[]> bruteForce = bruteForce(iArr, TAGS.length, TRANSITION_WEIGHTS, TOKEN_WEIGHTS);
                assertCorrectAnswer(CRF, tokens, bruteForce, TAGS);
                assertCorrectAnswer(chainCrf, tokens, bruteForce, TAGS);
                assertCorrectNBest(bruteForce, CRF.tagNBest(tokens, Integer.MAX_VALUE), TAGS, false);
                assertCorrectNBest(bruteForce, CRF.tagNBestConditional(tokens, Integer.MAX_VALUE), TAGS, true);
                assertCorrectMarginal(bruteForce, CRF.tagMarginal(tokens), TAGS, tokens);
            }
        }
    }

    /**
     * Checks a marginal tag lattice against the brute-force scores.
     *
     * <p>Asserts that the lattice reports the expected tokens, that its log partition function
     * matches the brute-force one (tolerance 0.001), that each per-position, per-tag log
     * probability matches the brute-force log marginal (tolerance 1e-4), and that the
     * probabilities at each position sum to 1 (tolerance 0.01).
     *
     * @param objectToDoubleMap brute-force score for every candidate tag-index sequence
     * @param tagLattice lattice produced by the CRF for the same tokens
     * @param strArr tag names; only the length is used
     * @param list the input tokens
     */
    void assertCorrectMarginal(ObjectToDoubleMap<int[]> objectToDoubleMap, TagLattice<String> tagLattice, String[] strArr, List<String> list) {
        Assert.assertEquals(list, tagLattice.tokenList());
        double logZ = logZ(objectToDoubleMap);
        Assert.assertEquals(logZ, tagLattice.logZ(), 0.001d);
        List<String> tagList = tagLattice.tagList();
        for (int i = 0; i < list.size(); i++) {
            double d = 0.0d;
            for (int i2 = 0; i2 < tagList.size(); i2++) {
                // Fully qualified: the unqualified name Math is bound to com.aliasi.util.Math
                // by this file's imports, which is not where the standard exp lives.
                d += java.lang.Math.exp(tagLattice.logProbability(i, i2));
                Assert.assertEquals(logMarginal(objectToDoubleMap, i, i2, strArr.length, logZ), tagLattice.logProbability(i, i2), 1.0E-4d);
            }
            Assert.assertEquals("marginals norm " + i + Strings.SINGLE_SPACE_STRING + list, 1.0d, d, 0.01d);
        }
    }

    /**
     * Computes the brute-force log marginal of one tag at one position: the log-sum-exp of the
     * scores of all candidate sequences carrying that tag at that position, minus the log
     * partition function.
     *
     * @param objectToDoubleMap score for every candidate tag-index sequence
     * @param i token position
     * @param i2 tag index required at that position
     * @param i3 number of tags (not used)
     * @param d log partition function to subtract
     * @return the log marginal
     */
    static double logMarginal(ObjectToDoubleMap<int[]> objectToDoubleMap, int i, int i2, int i3, double d) {
        // Gather the scores of the matching sequences in the map's iteration order.
        List<Double> matchingScores = new ArrayList<>();
        for (Map.Entry<int[], Double> candidate : objectToDoubleMap.entrySet()) {
            int[] tagIndexes = candidate.getKey();
            if (tagIndexes[i] == i2) {
                matchingScores.add(candidate.getValue());
            }
        }
        double[] logScores = new double[matchingScores.size()];
        for (int n = 0; n < logScores.length; n++) {
            logScores[n] = matchingScores.get(n).doubleValue();
        }
        double logSum = Math.logSumOfExponentials(logScores);
        return logSum - d;
    }

    /**
     * Computes the brute-force log partition function: the log-sum-exp of every score in the map.
     *
     * @param objectToDoubleMap score for every candidate tag-index sequence
     * @return log of the sum of the exponentiated scores
     */
    static double logZ(ObjectToDoubleMap<int[]> objectToDoubleMap) {
        double[] logScores = new double[objectToDoubleMap.size()];
        int next = 0;
        for (Double logScore : objectToDoubleMap.values()) {
            logScores[next] = logScore.doubleValue();
            next++;
        }
        return Math.logSumOfExponentials(logScores);
    }

    /**
     * Checks an n-best iterator against the brute-force scores.
     *
     * <p>Every tagging the iterator yields must be one of the brute-force candidates and carry
     * that candidate's score (minus the log partition function when {@code z} is true), within
     * 1e-4. The set of taggings yielded must equal the set of brute-force candidates.
     *
     * <p>Taggings are compared by concatenating their tag names without a separator, so this
     * relies on the tag names being unambiguous when concatenated.
     *
     * @param objectToDoubleMap brute-force score for every candidate tag-index sequence
     * @param it n-best output of the CRF; consumed completely
     * @param strArr tag names, indexed by tag index
     * @param z true if the iterator's scores are expected to have the log partition function subtracted
     */
    void assertCorrectNBest(ObjectToDoubleMap<int[]> objectToDoubleMap, Iterator<ScoredTagging<String>> it, String[] strArr, boolean z) {
        double logZ = z ? logZ(objectToDoubleMap) : 0.0d;

        // Re-key the brute-force scores by the concatenated tag names.
        ObjectToDoubleMap<String> expectedScores = new ObjectToDoubleMap<>();
        for (Map.Entry<int[], Double> entry : objectToDoubleMap.entrySet()) {
            StringBuilder sb = new StringBuilder();
            for (int tagIndex : entry.getKey()) {
                sb.append(strArr[tagIndex]);
            }
            expectedScores.put(sb.toString(), entry.getValue());
        }
        TreeSet<String> expectedTaggings = new TreeSet<>(expectedScores.keySet());

        TreeSet<String> foundTaggings = new TreeSet<>();
        while (it.hasNext()) {
            ScoredTagging<String> next = it.next();
            double score = next.score();
            StringBuilder sb = new StringBuilder();
            for (String tag : next.tags()) {
                sb.append(tag);
            }
            String key = sb.toString();
            foundTaggings.add(key);
            Double expected = expectedScores.get(key);
            // Fail with a readable message rather than a NullPointerException on an unknown tagging.
            Assert.assertNotNull("n-best produced a tagging that brute force did not: " + key, expected);
            Assert.assertEquals(expected.doubleValue() - logZ, score, 1.0E-4d);
        }
        Assert.assertEquals(expectedTaggings, foundTaggings);
    }

    /**
     * Checks that {@code allArrays(length, 5)} yields 5^length arrays for lengths 0 through 3.
     */
    @Test
    public void testAllOutputsSizes() {
        int expectedCount = 1;
        for (int length = 0; length <= 3; length++) {
            Assert.assertEquals(expectedCount, allArrays(length, 5).size());
            expectedCount *= 5;
        }
    }

    /**
     * Asserts that the CRF's first-best tagging is one of the top-scoring brute-force candidates.
     *
     * <p>The candidates are walked in the order given by {@code keysOrderedByValueList()},
     * taking the first one's score as the best score. The check succeeds when a candidate
     * equal to the CRF output is reached before any candidate scoring below the best. It fails
     * if a lower-scoring candidate is reached first, or if no candidate matches at all.
     *
     * @param chainCrf CRF to decode with
     * @param list input tokens
     * @param objectToDoubleMap brute-force score for every candidate tag-index sequence
     * @param strArr tag names, indexed by tag index
     */
    static void assertCorrectAnswer(ChainCrf<String> chainCrf, List<String> list, ObjectToDoubleMap<int[]> objectToDoubleMap, String[] strArr) {
        List<String> tags = chainCrf.tag(list).tags();
        // NOTE(review): presumably ordered from highest to lowest score -- confirm against ObjectToDoubleMap.
        List<int[]> keysOrderedByValueList = objectToDoubleMap.keysOrderedByValueList();
        double value = objectToDoubleMap.getValue(keysOrderedByValueList.get(0));
        for (int[] iArr : keysOrderedByValueList) {
            if (objectToDoubleMap.getValue(iArr) < value) {
                Assert.fail("first-best tagging " + tags + " for tokens " + list + " is not among the top-scoring brute-force candidates");
            }
            if (areEqualTags(tags, iArr, strArr)) {
                Asserts.succeed();
                return;
            }
        }
        // Every candidate tied for best yet none matched: the output is not a candidate at all.
        Assert.fail("first-best tagging " + tags + " for tokens " + list + " matches no brute-force candidate");
    }

    /**
     * Tests whether a list of tag names spells out the same tagging as an array of tag indexes.
     *
     * @param list tag names to test
     * @param iArr tag indexes into {@code strArr}
     * @param strArr tag names, indexed by tag index
     * @return true if both have the same length and {@code list.get(i)} equals
     *     {@code strArr[iArr[i]]} at every position
     */
    static boolean areEqualTags(List<String> list, int[] iArr, String[] strArr) {
        // Without this check a longer list would match on its prefix and a shorter one would throw.
        if (list.size() != iArr.length) {
            return false;
        }
        for (int i = 0; i < iArr.length; i++) {
            if (!list.get(i).equals(strArr[iArr[i]])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Scores every possible tag-index sequence for a token-index sequence by exhaustive enumeration.
     *
     * @param iArr token indexes of the input
     * @param i number of tags; candidates use tag indexes {@code 0 .. i-1}
     * @param dArr transition weights, passed through to {@code score}
     * @param dArr2 token weights, passed through to {@code score}
     * @return map from each candidate tag-index array to its score
     */
    static ObjectToDoubleMap<int[]> bruteForce(int[] iArr, int i, double[][] dArr, double[][] dArr2) {
        ObjectToDoubleMap<int[]> scored = new ObjectToDoubleMap<>();
        List<int[]> candidates = allArrays(iArr.length, i);
        for (int n = 0; n < candidates.size(); n++) {
            int[] candidate = candidates.get(n);
            double candidateScore = score(iArr, candidate, dArr, dArr2);
            scored.put(candidate, Double.valueOf(candidateScore));
        }
        return scored;
    }

    /**
     * Computes the unnormalized score of one tagging: the sum of the token weights at every
     * position, followed by the sum of the transition weights between adjacent positions.
     *
     * @param iArr token index at each position
     * @param iArr2 tag index at each position
     * @param dArr transition weights, read as {@code dArr[currentTag][previousTag]}
     * @param dArr2 token weights, read as {@code dArr2[tag][token]}
     * @return the total score; 0 for an empty input
     */
    static double score(int[] iArr, int[] iArr2, double[][] dArr, double[][] dArr2) {
        final int length = iArr.length;
        double total = 0.0d;
        // Emission terms first, then transition terms, so the summation order is fixed.
        for (int pos = 0; pos < length; pos++) {
            int tag = iArr2[pos];
            int token = iArr[pos];
            total += dArr2[tag][token];
        }
        for (int prev = 0; prev + 1 < length; prev++) {
            int previousTag = iArr2[prev];
            int currentTag = iArr2[prev + 1];
            total += dArr[currentTag][previousTag];
        }
        return total;
    }

    /**
     * Enumerates every array of the given length whose entries are drawn from {@code 0 .. i2-1}.
     *
     * @param i length of each array
     * @param i2 number of distinct values per entry
     * @return a list of {@code i2^i} freshly allocated arrays; a single empty array when {@code i} is 0
     */
    static List<int[]> allArrays(int i, int i2) {
        List<int[]> arrayList = new ArrayList<>();
        allArrays(i, i2, new int[i], arrayList);
        return arrayList;
    }

    /**
     * Recursive helper: fills positions {@code i-1} down to 0 of the work array with every
     * combination of values and appends a copy of each completed array to the list.
     *
     * @param i number of leading positions still to be filled
     * @param i2 number of distinct values per entry
     * @param iArr work array, overwritten in place; positions at or beyond {@code i} are already set
     * @param list receives a clone of the work array for each completed combination
     */
    static void allArrays(int i, int i2, int[] iArr, List<int[]> list) {
        if (i == 0) {
            // Clone because the work array keeps being overwritten by the enclosing loops.
            list.add(iArr.clone());
            return;
        }
        for (int i3 = 0; i3 < i2; i3++) {
            iArr[i - 1] = i3;
            allArrays(i - 1, i2, iArr, list);
        }
    }

    /**
     * Estimates a CRF from {@link TestCorpus} and checks its first-best output.
     *
     * <p>Sentences built from words in the training data must receive the expected tags. For
     * inputs containing an unseen word or an odd token order, the test only requires that
     * tagging returns a non-null result.
     *
     * @throws Exception if estimation fails
     */
    @Test
    public void testEstimate() throws Exception {
        // Typed (not raw) so the assertTagging calls below are checked by the compiler.
        ChainCrf<String> estimate = ChainCrf.estimate(new TestCorpus(), FEATURE_EXTRACTOR, true, 1, true, true, RegressionPrior.gaussian(10.0d, true), 3, AnnealingSchedule.exponential(0.02d, 0.995d), 1.0E-5d, 2, 2000, null);
        assertTagging(Arrays.asList("John", "ran", "."), Arrays.asList("PN", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("Mary", "ran", "."), Arrays.asList("PN", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("The", "dog", "ran", "."), Arrays.asList("DET", "N", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("The", "dog", "ran", "!"), Arrays.asList("DET", "N", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("The", "dog", "sat", "!"), Arrays.asList("DET", "N", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("The", "dog", "sat", "."), Arrays.asList("DET", "N", "IV", "EOS"), estimate);
        assertTagging(Arrays.asList("John", "likes", "Mary", "."), Arrays.asList("PN", "TV", "PN", "EOS"), estimate);
        assertTagging(Arrays.asList("Mary", "likes", "John", "."), Arrays.asList("PN", "TV", "PN", "EOS"), estimate);
        // "Fred" never occurs in the corpus and the last input is scrambled: only require a result.
        Assert.assertNotNull(estimate.tag(Arrays.asList("Fred", "likes", "John", ".")));
        Assert.assertNotNull(estimate.tag(Arrays.asList(";", ".", "likes", "likes")));
    }

    /**
     * Asserts that the CRF's first-best tags for the given tokens equal the expected tags.
     *
     * @param list input tokens
     * @param list2 expected tags, one per token
     * @param chainCrf CRF to decode with
     */
    static <E> void assertTagging(List<E> list, List<String> list2, ChainCrf<E> chainCrf) {
        List<String> actualTags = chainCrf.tag(list).tags();
        Assert.assertEquals(list2, actualTags);
    }
}
