package de.julielab.gene.candidateretrieval.scoring;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Token2FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import java.util.ArrayList;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/julielab/gene/candidateretrieval/scoring/MaxEntScorerML.class */
public class MaxEntScorerML {
    private static final Logger LOGGER = LoggerFactory.getLogger(MaxEntScorerML.class);

    /** Label whose predicted probability is reported as the score of an instance. */
    private static final String TRUE_LABEL = "TRUE";

    /**
     * Builds an {@link InstanceList} by running every given string array through an existing pipe.
     *
     * <p>Each raw {@link Instance} is created with the array as its data and empty strings as target,
     * name and source. Presumably the pipe fills these in (confirm against the feature pipe). The log
     * message calls this the "old pipe" path, which suggests the pipe of an already trained
     * classifier is passed here. NOTE(review): confirm with callers.
     *
     * @param arrayList the entries to convert (called "pairs" in the log output)
     * @param pipe the pipe used to transform every entry
     * @return a new instance list backed by {@code pipe}, with one instance per entry
     */
    public InstanceList makeInstances(ArrayList<String[]> arrayList, Pipe pipe) {
        LOGGER.debug("makeInstances() - making instances for pairs with old pipe ...");
        return addAllThruPipe(new InstanceList(pipe), arrayList);
    }

    /**
     * Builds an {@link InstanceList} using a freshly created pipe.
     *
     * <p>The pipe is {@link MaxEntScorerFeaturePipe} followed by {@link Token2FeatureVector}. Because
     * a new pipe is created on every call, the resulting list does not share a pipe with any
     * previously built list or trained classifier.
     *
     * @param arrayList the entries to convert (called "pairs" in the log output)
     * @return a new instance list with one instance per entry
     */
    public InstanceList makeInstances(ArrayList<String[]> arrayList) {
        LOGGER.debug("makeInstances() - making instances for pairs with new pipe ...");
        Pipe pipe = new SerialPipes(new Pipe[]{new MaxEntScorerFeaturePipe(), new Token2FeatureVector()});
        return addAllThruPipe(new InstanceList(pipe), arrayList);
    }

    /**
     * Trains a maximum-entropy classifier on the given instances with default trainer settings.
     *
     * @param instanceList the training instances
     * @return the trained classifier
     */
    public Classifier train(InstanceList instanceList) {
        LOGGER.debug("train() - training the model from {} training examples ...", instanceList.size());
        return new MaxEntTrainer().train(instanceList);
    }

    /**
     * Classifies a single instance and returns the probability assigned to the "TRUE" label.
     *
     * @param instance the instance to score (must already be piped compatibly with the classifier)
     * @param classifier the trained classifier
     * @return the probability of the "TRUE" label, or 0.0 if the classifier has no such label
     */
    public double predict(Instance instance, Classifier classifier) {
        return getProbabilityTrueClass(classifier.classify(instance));
    }

    /**
     * Classifies every instance and prints a four-line report per instance to standard output.
     *
     * <p>Each report contains the instance source, the "TRUE" probability, the instance name (printed
     * as the correct class), and the best predicted label.
     *
     * @param classifier the trained classifier
     * @param instanceList the instances to evaluate
     */
    public void eval(Classifier classifier, InstanceList instanceList) {
        for (Classification classification : classifier.classify(instanceList)) {
            Labeling labeling = classification.getLabeling();
            double probabilityTrueClass = getProbabilityTrueClass(classification);
            System.out.println("           pair: " + classification.getInstance().getSource());
            System.out.println("predicted score: " + probabilityTrueClass);
            System.out.println("  correct class: " + classification.getInstance().getName());
            System.out.println("predicted class: " + labeling.getBestLabel() + "\n");
        }
    }

    /**
     * Pipes one raw instance per entry into {@code instanceList} and returns the same list.
     * Target, name and source of each raw instance are empty strings.
     */
    private static InstanceList addAllThruPipe(InstanceList instanceList, Iterable<String[]> entries) {
        for (String[] entry : entries) {
            instanceList.addThruPipe(new Instance(entry, "", "", ""));
        }
        return instanceList;
    }

    /**
     * Returns the probability the classification assigns to the "TRUE" label.
     *
     * <p>The lookup deliberately uses the non-growing form. The default {@code lookupLabel(Object)}
     * would add "TRUE" to the classifier's shared label alphabet when it is missing, mutating the
     * trained model. If the label is absent, 0.0 is returned and a warning is logged.
     */
    private double getProbabilityTrueClass(Classification classification) {
        Labeling labeling = classification.getLabeling();
        if (!labeling.getLabelAlphabet().contains(TRUE_LABEL)) {
            LOGGER.warn("Label alphabet does not contain the label '{}'; returning probability 0.0", TRUE_LABEL);
            return 0.0;
        }
        return labeling.value(labeling.getLabelAlphabet().lookupLabel(TRUE_LABEL, false));
    }
}
