package edu.umass.cs.mallet.base.classify.tests;

import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.NaiveBayes;
import edu.umass.cs.mallet.base.classify.NaiveBayesTrainer;
import edu.umass.cs.mallet.base.pipe.CharSequence2TokenSequence;
import edu.umass.cs.mallet.base.pipe.FeatureSequence2FeatureVector;
import edu.umass.cs.mallet.base.pipe.Input2CharSequence;
import edu.umass.cs.mallet.base.pipe.Noop;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.Target2Label;
import edu.umass.cs.mallet.base.pipe.TokenSequence2FeatureSequence;
import edu.umass.cs.mallet.base.pipe.TokenSequenceLowercase;
import edu.umass.cs.mallet.base.pipe.TokenSequenceRemoveStopwords;
import edu.umass.cs.mallet.base.pipe.iterator.ArrayIterator;
import edu.umass.cs.mallet.base.pipe.iterator.FileIterator;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.Multinomial;
import edu.umass.cs.mallet.base.util.Random;
import java.io.File;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;

/* loaded from: input_file:edu/umass/cs/mallet/base/classify/tests/TestNaiveBayes.class */
/**
 * JUnit 3 tests for {@link NaiveBayes} and {@link NaiveBayesTrainer}: a hand-built model,
 * training from in-memory strings, training on random data, and training from on-disk fixtures.
 */
public class TestNaiveBayes extends TestCase {
    /** Relative directory holding the "learn" fixtures used by the incremental-training tests. */
    private static final String LEARN_DIR = "src/edu/umass/cs/mallet/base/classify/tests/NaiveBayesData/learn/";

    public TestNaiveBayes(String str) {
        super(str);
    }

    /**
     * Builds a NaiveBayes model directly from hand-set priors and per-class feature counts
     * (no trainer), then checks that a document containing "speech" and "win" is classified
     * as "politics" with probability above 0.6.
     */
    public void testNonTrained() {
        Alphabet alphabet = new Alphabet();
        System.out.println("fdict.size=" + alphabet.size());
        LabelAlphabet labelAlphabet = new LabelAlphabet();
        Multinomial.LaplaceEstimator sportsEstimator = new Multinomial.LaplaceEstimator(alphabet);
        Multinomial.LaplaceEstimator politicsEstimator = new Multinomial.LaplaceEstimator(alphabet);
        labelAlphabet.lookupIndex("sports");
        labelAlphabet.lookupIndex("politics");
        labelAlphabet.stopGrowth();
        System.out.println("ldict.size=" + labelAlphabet.size());
        // Uniform prior over the two labels.
        Multinomial prior = new Multinomial(new double[]{0.5d, 0.5d}, labelAlphabet);
        sportsEstimator.increment("win", 5.0d);
        sportsEstimator.increment("puck", 5.0d);
        sportsEstimator.increment("team", 5.0d);
        System.out.println("fdict.size=" + alphabet.size());
        politicsEstimator.increment("win", 5.0d);
        politicsEstimator.increment("speech", 5.0d);
        politicsEstimator.increment("vote", 5.0d);
        // Estimator order must match label index order: "sports" (0), then "politics" (1).
        NaiveBayes naiveBayes = new NaiveBayes(new Noop(alphabet, labelAlphabet), prior,
                new Multinomial[]{sportsEstimator.estimate(), politicsEstimator.estimate()});
        Instance instance = new Instance(
                new FeatureVector(alphabet, new Object[]{"speech", "win"}, new double[]{1.0d, 1.0d}),
                labelAlphabet.lookupLabel("politics"), null, null, naiveBayes.getInstancePipe());
        System.out.println("inst.data = " + instance.getData());
        Classification classification = naiveBayes.classify(instance);
        System.out.println("l.getBestIndex=" + ((LabelVector) classification.getLabeling()).getBestIndex());
        // Compared by reference, as in the original test: the expected Label comes from
        // the same LabelAlphabet via lookupLabel.
        assertSame(labelAlphabet.lookupLabel("politics"), classification.getLabeling().getBestLabel());
        assertTrue(classification.getLabeling().getBestValue() > 0.6d);
    }

    /**
     * Trains on two small labelled string corpora ("africa", "asia") and checks that an
     * unseen sentence sharing vocabulary with the africa corpus is labelled "africa".
     */
    public void testStringTrained() {
        InstanceList instanceList = new InstanceList(new SerialPipes(new Pipe[]{
                new Target2Label(), new CharSequence2TokenSequence(),
                new TokenSequence2FeatureSequence(), new FeatureSequence2FeatureVector()}));
        instanceList.add(new ArrayIterator(new String[]{
                "on the plains of africa the lions roar",
                "in swahili ngoma means to dance",
                "nelson mandela became president of south africa",
                "the saraha dessert is expanding"}, "africa"));
        instanceList.add(new ArrayIterator(new String[]{
                "panda bears eat bamboo",
                "china's one child policy has resulted in a surplus of boys",
                "tigers live in the jungle"}, "asia"));
        Classifier classifier = new NaiveBayesTrainer().train(instanceList);
        assertSame(((LabelAlphabet) instanceList.getTargetAlphabet()).lookupLabel("africa"),
                classifier.classify("nelson mandela never eats lions").getLabeling().getBestLabel());
    }

    /**
     * Trains on a seeded random InstanceList and prints every classification followed by
     * the training-set accuracy. Nothing is asserted.
     */
    public void testRandomTrained() {
        InstanceList instanceList = new InstanceList(new Random(1L), 10, 2);
        Classifier classifier = new NaiveBayesTrainer().train(instanceList);
        int numCorrect = 0;
        for (int i = 0; i < instanceList.size(); i++) {
            Instance instance = instanceList.getInstance(i);
            Classification classification = classifier.classify(instance);
            classification.print();
            if (classification.getLabeling().getBestLabel() == instance.getLabeling().getBestLabel()) {
                numCorrect++;
            }
        }
        // Divide in floating point: integer division truncated any accuracy below 100% to 0.
        System.out.println("Accuracy on training set = " + ((double) numCorrect / instanceList.size()));
    }

    /**
     * Loads two directories through a shared pipe, then a third directory through the same
     * pipe, printing alphabet sizes after each load to show the alphabets growing.
     */
    public void testIncrementallyTrainedGrowingAlphabets() {
        System.out.println("testIncrementallyTrainedGrowingAlphabets");
        // NOTE(review): these absolute paths are machine-specific; this test presumably fails
        // on any other host -- confirm whether the fixtures should move under LEARN_DIR.
        File[] fileArr = toFiles(new String[]{
                "/usr/gob/users1/hough/melinda-bug2/learn/project",
                "/usr/gob/users1/hough/melinda-bug2/learn/subfdr"});
        SerialPipes serialPipes = newTextPipe();
        InstanceList instanceList = new InstanceList(serialPipes);
        instanceList.add(new FileIterator(fileArr, FileIterator.STARTING_DIRECTORIES));
        System.out.println("Training 1");
        printAlphabetSizes(instanceList);
        InstanceList instanceList2 = new InstanceList(serialPipes);
        instanceList2.add(new FileIterator(new String[]{"/usr/gob/users1/hough/melinda-bug2/update/subfdr"},
                FileIterator.STARTING_DIRECTORIES));
        System.out.println("Training 2");
        printAlphabetSizes(instanceList2);
    }

    /**
     * Incrementally trains on the "a" and "b" fixtures, prints classifications of two probe
     * strings, then reloads "b" through the same pipe and prints the alphabet sizes.
     */
    public void testIncrementallyTrained() {
        System.out.println("testIncrementallyTrained");
        SerialPipes serialPipes = newTextPipe();
        InstanceList instanceList = new InstanceList(serialPipes);
        instanceList.add(new FileIterator(toFiles(new String[]{LEARN_DIR + "a", LEARN_DIR + "b"}),
                FileIterator.STARTING_DIRECTORIES));
        trainIncrementallyAndReport(instanceList);
        // NOTE(review): despite the "Training 2" label, the second list is only loaded, never
        // passed to incrementalTrain -- confirm whether a second training call was intended.
        InstanceList instanceList2 = new InstanceList(serialPipes);
        instanceList2.add(new FileIterator(new String[]{LEARN_DIR + "b"}, FileIterator.STARTING_DIRECTORIES));
        System.out.println("Training 2");
        printAlphabetSizes(instanceList2);
    }

    /**
     * Same setup as {@link #testIncrementallyTrained()}, but reloads "b" using the
     * three-argument FileIterator constructor and re-classifies a probe afterwards.
     */
    public void testEmptyStringBug() {
        System.out.println("testEmptyStringBug");
        SerialPipes serialPipes = newTextPipe();
        InstanceList instanceList = new InstanceList(serialPipes);
        instanceList.add(new FileIterator(toFiles(new String[]{LEARN_DIR + "a", LEARN_DIR + "b"}),
                FileIterator.STARTING_DIRECTORIES));
        NaiveBayes naiveBayes = trainIncrementallyAndReport(instanceList);
        InstanceList instanceList2 = new InstanceList(serialPipes);
        instanceList2.add(new FileIterator(new String[]{LEARN_DIR + "b"}, FileIterator.STARTING_DIRECTORIES, true));
        System.out.println("Training 2");
        printAlphabetSizes(instanceList2);
        naiveBayes.classify("Goodbye now").print();
    }

    /** Builds a new text pipeline: label, tokenize, lowercase, remove stopwords, vectorize. */
    private static SerialPipes newTextPipe() {
        return new SerialPipes(new Pipe[]{
                new Target2Label(), new Input2CharSequence(), new CharSequence2TokenSequence(),
                new TokenSequenceLowercase(), new TokenSequenceRemoveStopwords(),
                new TokenSequence2FeatureSequence(), new FeatureSequence2FeatureVector()});
    }

    /** Converts path strings to File objects, preserving order. */
    private static File[] toFiles(String[] paths) {
        File[] files = new File[paths.length];
        for (int i = 0; i < paths.length; i++) {
            files[i] = new File(paths[i]);
        }
        return files;
    }

    /** Prints the sizes of the list's data and target alphabets. */
    private static void printAlphabetSizes(InstanceList list) {
        System.out.println("data alphabet size " + list.getDataAlphabet().size());
        System.out.println("target alphabet size " + list.getTargetAlphabet().size());
    }

    /**
     * Incrementally trains a NaiveBayes model on {@code instanceList}, prints its
     * classifications of two probe strings, its alphabets and the list's alphabet sizes.
     *
     * @param instanceList the training data
     * @return the trained model
     */
    private static NaiveBayes trainIncrementallyAndReport(InstanceList instanceList) {
        System.out.println("Training 1");
        NaiveBayes naiveBayes = (NaiveBayes) new NaiveBayesTrainer().incrementalTrain(instanceList);
        Classification hello = naiveBayes.classify("Hello Everybody");
        Classification goodbye = naiveBayes.classify("Goodbye now");
        System.out.println("Initial Classification = ");
        hello.print();
        goodbye.print();
        System.out.println("data alphabet " + naiveBayes.getAlphabet());
        System.out.println("label alphabet " + naiveBayes.getLabelAlphabet());
        printAlphabetSizes(instanceList);
        return naiveBayes;
    }

    /** JUnit 3 runners only discover a {@code public static} suite method. */
    public static Test suite() {
        return new TestSuite(TestNaiveBayes.class);
    }

    protected void setUp() {
    }

    public static void main(String[] strArr) {
        TestRunner.run(suite());
    }
}
