package edu.umass.cs.mallet.base.extract.test;

import edu.umass.cs.mallet.base.extract.CRFExtractor;
import edu.umass.cs.mallet.base.extract.Extraction;
import edu.umass.cs.mallet.base.extract.LatticeViewer;
import edu.umass.cs.mallet.base.fst.CRF4;
import edu.umass.cs.mallet.base.fst.MEMM;
import edu.umass.cs.mallet.base.fst.TokenAccuracyEvaluator;
import edu.umass.cs.mallet.base.fst.tests.TestCRF;
import edu.umass.cs.mallet.base.fst.tests.TestMEMM;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.iterator.ArrayIterator;
import edu.umass.cs.mallet.base.pipe.iterator.PipeInputIterator;
import edu.umass.cs.mallet.base.types.InstanceList;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/**
 * JUnit tests for {@link LatticeViewer}.
 *
 * <p>Each test trains a small CRF on {@link TestCRF#data}, using the
 * space-prediction pipe from {@link TestMEMM}. {@link #testDualSpaceViewer()}
 * also trains an MEMM. The tests then render the resulting extractions as HTML.
 *
 * <p>Output is written under the directory named by the system property
 * {@code mallet.test.outputDir} (default {@code /Scratch/output}). Missing
 * directories are created. The tests check only that rendering completes and
 * that the HTML files are written without I/O errors; they do not check the
 * HTML content.
 */
public class TestLatticeViewer extends TestCase {

    /** Root directory for all HTML output; override with -Dmallet.test.outputDir=... */
    private static final File OUTPUT_DIR =
            new File(System.getProperty("mallet.test.outputDir", "/Scratch/output"));

    /** Target of the three-argument {@code LatticeViewer.extraction2html} call. */
    private static final File HTML_FILE = new File(OUTPUT_DIR, "errors.html");

    /** Target of the four-argument {@code extraction2html(..., true)} call. */
    private static final File LATTICE_FILE = new File(OUTPUT_DIR, "lattice.html");

    /** Directory handed to {@code LatticeViewer.viewDualResults}. */
    private static final File HTML_DIR = new File(OUTPUT_DIR, "html");

    /**
     * Number of leading pipes that {@link #hackCrfExtor(CRF4)} moves off the
     * model's input pipe. Presumably these are the tokenization stages built by
     * {@code TestMEMM.makeSpacePredictionPipe()}; TODO confirm.
     */
    private static final int NUM_TOKENIZATION_PIPES = 3;

    public TestLatticeViewer(String name) {
        super(name);
    }

    /**
     * Trains a CRF on {@code TestCRF.data[0]} and extracts {@code TestCRF.data[1]}.
     *
     * <p>The extraction is written twice: to {@link #HTML_FILE} via the
     * three-argument {@code extraction2html} overload, and to
     * {@link #LATTICE_FILE} with the extra {@code true} flag.
     *
     * @throws FileNotFoundException if an output file cannot be opened
     */
    public void testSpaceViewer() throws FileNotFoundException {
        Pipe pipe = TestMEMM.makeSpacePredictionPipe();
        String[] trainData = {TestCRF.data[0]};
        String[] testData = {TestCRF.data[1]};

        InstanceList training = newInstanceList(pipe, trainData);
        // The returned list is never read. It is kept because pushing the test data
        // through the shared pipe presumably grows the pipe's alphabets before
        // training, and dropping it could change the trained model. TODO confirm.
        newInstanceList(pipe, testData);

        CRF4 crf = new CRF4(pipe, (Pipe) null);
        crf.addFullyConnectedStatesForLabels();
        crf.train(training, null, null, null);

        CRFExtractor extor = hackCrfExtor(crf);
        Extraction extraction = extor.extract((PipeInputIterator) new ArrayIterator(testData));

        PrintStream out = openOutput(HTML_FILE);
        try {
            LatticeViewer.extraction2html(extraction, extor, out);
        } finally {
            out.close();
        }
        assertWrittenCleanly(out, HTML_FILE);

        PrintStream latticeOut = openOutput(LATTICE_FILE);
        try {
            LatticeViewer.extraction2html(extraction, extor, latticeOut, true);
        } finally {
            latticeOut.close();
        }
        assertWrittenCleanly(latticeOut, LATTICE_FILE);
    }

    /**
     * Wraps {@code crf} in a {@link CRFExtractor}.
     *
     * <p>The first {@value #NUM_TOKENIZATION_PIPES} pipes are detached from the
     * model's input pipe and placed in a new {@link SerialPipes}, which becomes
     * the extractor's second constructor argument. That argument is presumably
     * the tokenization pipeline; TODO confirm.
     *
     * <p><b>Side effect:</b> this mutates {@code crf.getInputPipe()}, which loses
     * those leading pipes.
     *
     * <p>This method was package-private in the original source. The decompiled
     * form made it public, and it stays public so existing callers still compile.
     *
     * @param crf a trained model whose input pipe is a {@link SerialPipes}
     * @return an extractor around {@code crf}
     * @throws ClassCastException if the input pipe is not a {@link SerialPipes}
     */
    public static CRFExtractor hackCrfExtor(CRF4 crf) {
        SerialPipes inputPipe = (SerialPipes) crf.getInputPipe();
        Pipe[] leadingPipes = new Pipe[NUM_TOKENIZATION_PIPES];
        for (int i = 0; i < NUM_TOKENIZATION_PIPES; i++) {
            // Repeatedly take the head: each removal shifts the rest down to index 0.
            Pipe head = inputPipe.getPipe(0);
            inputPipe.removePipe(0);
            // Detach from the old parent so the new SerialPipes can adopt it.
            head.setParent(null);
            leadingPipes[i] = head;
        }
        return new CRFExtractor(crf, new SerialPipes(leadingPipes));
    }

    /**
     * Trains a CRF and an MEMM on {@code TestCRF.data[0]}, each with its own
     * fresh space-prediction pipe, evaluating on all of {@code TestCRF.data}.
     *
     * <p>Both models extract all of {@code TestCRF.data}. The two extractions
     * are rendered side by side into {@link #HTML_DIR} via
     * {@code LatticeViewer.viewDualResults}.
     *
     * @throws IOException if the viewer fails to write its output
     */
    public void testDualSpaceViewer() throws IOException {
        Pipe crfPipe = TestMEMM.makeSpacePredictionPipe();
        String[] trainData = {TestCRF.data[0]};
        String[] allData = TestCRF.data;

        InstanceList crfTraining = newInstanceList(crfPipe, trainData);
        InstanceList crfTesting = newInstanceList(crfPipe, allData);
        CRF4 crf = new CRF4(crfPipe, (Pipe) null);
        crf.addFullyConnectedStatesForLabels();
        // The trailing 5 is presumably the iteration limit; TODO confirm against CRF4.train.
        crf.train(crfTraining, null, crfTesting, new TokenAccuracyEvaluator(), 5);
        CRFExtractor crfExtor = hackCrfExtor(crf);
        Extraction crfExtraction = crfExtor.extract((PipeInputIterator) new ArrayIterator(allData));

        // The MEMM gets its own pipe so its alphabets are independent of the CRF's.
        Pipe memmPipe = TestMEMM.makeSpacePredictionPipe();
        InstanceList memmTraining = newInstanceList(memmPipe, trainData);
        InstanceList memmTesting = newInstanceList(memmPipe, allData);
        MEMM memm = new MEMM(memmPipe, (Pipe) null);
        memm.addFullyConnectedStatesForLabels();
        memm.train(memmTraining, null, memmTesting, new TokenAccuracyEvaluator(), 5);
        CRFExtractor memmExtor = hackCrfExtor(memm);
        Extraction memmExtraction = memmExtor.extract((PipeInputIterator) new ArrayIterator(allData));

        // Create the target directory up front; it is unknown whether
        // viewDualResults creates it itself.
        HTML_DIR.mkdirs();
        LatticeViewer.viewDualResults(HTML_DIR, crfExtraction, crfExtor, memmExtraction, memmExtor);
    }

    /**
     * Creates an {@link InstanceList} over {@code pipe} and adds {@code data}
     * through an {@link ArrayIterator}.
     */
    private static InstanceList newInstanceList(Pipe pipe, String[] data) {
        InstanceList list = new InstanceList(pipe);
        list.add(new ArrayIterator(data));
        return list;
    }

    /**
     * Opens {@code file} for writing, creating any missing parent directories
     * first.
     *
     * @throws FileNotFoundException if the file cannot be opened for writing
     */
    private static PrintStream openOutput(File file) throws FileNotFoundException {
        File parent = file.getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        return new PrintStream(new FileOutputStream(file));
    }

    /**
     * Fails the test if {@code out} recorded any I/O error.
     *
     * <p>This check is needed because {@link PrintStream} reports errors only
     * through {@link PrintStream#checkError()}, never by throwing.
     */
    private static void assertWrittenCleanly(PrintStream out, File file) {
        assertTrue("I/O error while writing " + file, !out.checkError());
    }

    public static Test suite() {
        return new TestSuite(TestLatticeViewer.class);
    }

    /**
     * Runs the named test methods, or the whole suite when no names are given.
     *
     * @param args optional test method names
     */
    public static void main(String[] args) throws Throwable {
        TestSuite suite;
        if (args.length > 0) {
            suite = new TestSuite();
            for (int i = 0; i < args.length; i++) {
                suite.addTest(new TestLatticeViewer(args[i]));
            }
        } else {
            suite = suite();
        }
        TestRunner.run(suite);
    }
}
