package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;

/**
 * A CRF sequence model whose learned weights are held in a {@link CRFParameters} instance.
 * <p>
 * How {@link #predict} and {@link #scoreSubsequences} attach confidence information is controlled
 * by {@link #setConfidenceType(ConfidenceType)}; the default is {@link ConfidenceType#NONE}.
 */
public class CRFModel extends ConfidencePredictingSequenceModel {
    private static final Logger logger = Logger.getLogger(CRFModel.class.getName());
    private static final long serialVersionUID = 2L;

    /** Pseudo-feature name under which each label's bias is reported by {@link #getTopFeatures(int)}. */
    private static final String BIAS_FEATURE_NAME = "BIAS";

    private final CRFParameters parameters;

    // Not final: callers may switch strategies after construction via setConfidenceType.
    private ConfidenceType confidenceType;

    /**
     * Selects how confidence information is produced by {@link CRFModel#predict} and
     * {@link CRFModel#scoreSubsequences}.
     */
    public enum ConfidenceType {
        /**
         * {@code predict} emits one label per position from {@code CRFParameters.predict};
         * subsequences are scored with {@code ConfidencePredictingSequenceModel.multiplyWeights}.
         */
        NONE,
        /**
         * {@code predict} emits, per position, the label with the highest value in the vector
         * returned by {@code CRFParameters.predictMarginals}, plus a score map over all labels;
         * subsequences are scored with {@code ConfidencePredictingSequenceModel.multiplyWeights}.
         */
        MULTIPLY,
        /**
         * {@code predict} behaves as for {@link #NONE}; subsequences are scored with
         * {@code CRFParameters.predictConfidenceUsingCBP}.
         */
        CONSTRAINED_BP
    }

    /**
     * Creates a model wrapping the supplied parameters, with confidence type {@link ConfidenceType#NONE}.
     * <p>
     * NOTE(review): the decompiler reports this constructor was package-private in the original
     * build; it is left public here so existing callers keep compiling.
     *
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The label domain.
     * @param parameters The trained CRF parameters.
     */
    public CRFModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
            ImmutableOutputInfo<Label> labelIDMap, CRFParameters parameters) {
        super(name, provenance, featureIDMap, labelIDMap);
        this.parameters = parameters;
        this.confidenceType = ConfidenceType.NONE;
    }

    /**
     * Sets the strategy used to attach confidence information to predictions.
     *
     * @param confidenceType The confidence strategy.
     */
    public void setConfidenceType(ConfidenceType confidenceType) {
        this.confidenceType = confidenceType;
    }

    /**
     * Returns the weight vector for the feature with the given id.
     *
     * @param featureID The feature id.
     * @return The feature's weights, or a zero-length vector (after logging a warning) when the id
     *         is outside {@code [0, featureIDMap.size())}.
     */
    public DenseVector getFeatureWeights(int featureID) {
        if (featureID >= 0 && featureID < featureIDMap.size()) {
            return parameters.getFeatureWeights(featureID);
        }
        logger.warning("Unknown feature");
        return new DenseVector(0);
    }

    /**
     * Returns the weight vector for the named feature.
     *
     * @param featureName The feature name.
     * @return The feature's weights, or a zero-length vector (after logging a warning) when the
     *         feature map does not assign the name a non-negative id.
     */
    public DenseVector getFeatureWeights(String featureName) {
        // Look the id up once rather than twice.
        int featureID = featureIDMap.getID(featureName);
        if (featureID > -1) {
            return getFeatureWeights(featureID);
        }
        logger.warning("Unknown feature");
        return new DenseVector(0);
    }

    /**
     * Predicts a label for every position in the sequence.
     *
     * @param example The sequence to label.
     * @return One prediction per position. When the confidence type is
     *         {@link ConfidenceType#MULTIPLY} each prediction carries a score map over all labels.
     * @throws IllegalArgumentException If the sequence is empty or any position has no known features.
     */
    public List<Prediction<Label>> predict(SequenceExample<Label> example) {
        SparseVector[] features = convert(example, featureIDMap);
        List<Prediction<Label>> predictions = new ArrayList<>(features.length);
        if (confidenceType == ConfidenceType.MULTIPLY) {
            DenseVector[] marginals = parameters.predictMarginals(features);
            for (int i = 0; i < marginals.length; i++) {
                predictions.add(marginalPrediction(marginals[i], features[i], example.get(i)));
            }
        } else {
            int[] labelIDs = parameters.predict(features);
            for (int i = 0; i < labelIDs.length; i++) {
                predictions.add(new Prediction<>(outputIDMap.getOutput(labelIDs[i]),
                        features[i].numActiveElements(), example.get(i)));
            }
        }
        return predictions;
    }

    /**
     * Builds a prediction whose output is the label with the largest value in {@code marginal}
     * (first one wins on ties) and whose score map holds every label with its value.
     */
    private Prediction<Label> marginalPrediction(DenseVector marginal, SparseVector features, Example<Label> example) {
        Map<String, Label> scores = new LinkedHashMap<>();
        Label best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < marginal.size(); j++) {
            String labelName = outputIDMap.getOutput(j).getLabel();
            Label scored = new Label(labelName, marginal.get(j));
            scores.put(labelName, scored);
            if (scored.getScore() > bestScore) {
                bestScore = scored.getScore();
                best = scored;
            }
        }
        return new Prediction<>(best, scores, features.numActiveElements(), example, true);
    }

    /**
     * Returns, for each label, the features with the largest weights, in descending weight order.
     * The label's bias is included as a pseudo-feature named {@code "BIAS"}.
     *
     * @param n The number of features to return per label. A negative value returns every
     *          feature plus the bias; zero returns an empty list per label.
     * @return A map from label name to its top (feature name, weight) pairs.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int numFeatures = featureIDMap.size();
        int maxFeatures = n < 0 ? numFeatures + 1 : n;
        // The queue never holds more than numFeatures + 1 entries, so do not let a huge n force a
        // huge backing array; PriorityQueue also rejects an initial capacity below 1.
        int capacity = Math.max(1, Math.min(maxFeatures, numFeatures + 1));
        Comparator<Pair<String, Double>> comparator = Comparator.comparing(Pair::getB);

        int numLabels = outputIDMap.size();
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        for (int labelID = 0; labelID < numLabels; labelID++) {
            // Min-heap of the best maxFeatures seen so far: the head is the weakest survivor.
            PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(capacity, comparator);
            for (int featureID = 0; featureID < numFeatures; featureID++) {
                Pair<String, Double> candidate = new Pair<>(featureIDMap.get(featureID).getName(),
                        parameters.getWeight(labelID, featureID));
                offerBounded(queue, candidate, maxFeatures, comparator);
            }
            offerBounded(queue, new Pair<>(BIAS_FEATURE_NAME, parameters.getBias(labelID)), maxFeatures, comparator);

            List<Pair<String, Double>> ranked = new ArrayList<>(queue.size());
            while (!queue.isEmpty()) {
                ranked.add(queue.poll());
            }
            // Polling a min-heap yields ascending order; flip it so the largest weight comes first.
            Collections.reverse(ranked);
            topFeatures.put(outputIDMap.getOutput(labelID).getLabel(), ranked);
        }
        return topFeatures;
    }

    /**
     * Offers {@code candidate} to the min-heap {@code queue}, keeping at most {@code limit}
     * entries by evicting the current smallest when a larger candidate arrives.
     */
    private static void offerBounded(PriorityQueue<Pair<String, Double>> queue, Pair<String, Double> candidate,
            int limit, Comparator<Pair<String, Double>> comparator) {
        if (limit <= 0) {
            return; // Nothing may be retained; also avoids comparing against an empty queue's null head.
        }
        if (queue.size() < limit) {
            queue.offer(candidate);
        } else if (comparator.compare(candidate, queue.peek()) > 0) {
            queue.poll();
            queue.offer(candidate);
        }
    }

    /**
     * Scores each subsequence of a previously predicted sequence.
     * <p>
     * Unless the confidence type is {@link ConfidenceType#CONSTRAINED_BP}, this delegates to
     * {@code ConfidencePredictingSequenceModel.multiplyWeights}. Otherwise each subsequence is
     * turned into a {@link Chunk} of the predicted label ids it covers and scored by
     * {@link #scoreChunks}.
     *
     * @param example The sequence that was predicted.
     * @param predictions The predictions for {@code example}, one per position.
     * @param subsequences The subsequences to score.
     * @param <SUB> The subsequence type.
     * @return One score per subsequence.
     */
    public <SUB extends ConfidencePredictingSequenceModel.Subsequence> List<Double> scoreSubsequences(
            SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences) {
        if (confidenceType != ConfidenceType.CONSTRAINED_BP) {
            return ConfidencePredictingSequenceModel.multiplyWeights(predictions, subsequences);
        }
        List<Chunk> chunks = new ArrayList<>(subsequences.size());
        for (SUB subsequence : subsequences) {
            int[] labelIDs = new int[subsequence.length()];
            for (int i = 0; i < labelIDs.length; i++) {
                labelIDs[i] = outputIDMap.getID(predictions.get(subsequence.begin + i).getOutput());
            }
            chunks.add(new Chunk(subsequence.begin, labelIDs));
        }
        return scoreChunks(example, chunks);
    }

    /**
     * Scores chunks of label ids against the sequence using
     * {@code CRFParameters.predictConfidenceUsingCBP}.
     *
     * @param example The sequence.
     * @param chunks The chunks to score.
     * @return One score per chunk.
     */
    public List<Double> scoreChunks(SequenceExample<Label> example, List<Chunk> chunks) {
        return parameters.predictConfidenceUsingCBP(convert(example, featureIDMap), chunks);
    }

    /**
     * Renders the bias, feature-label and label-label weight tensors as a human-readable string.
     *
     * @return The weights, one labelled section per tensor.
     */
    public String generateWeightsString() {
        Tensor[] weights = parameters.get();
        StringBuilder sb = new StringBuilder();
        sb.append("Biases = ").append(weights[0].toString()).append('\n');
        sb.append("Feature-Label weights = \n").append(weights[1].toString()).append('\n');
        sb.append("Label-Label weights = \n").append(weights[2].toString()).append('\n');
        return sb.toString();
    }

    /**
     * Converts each position of a sequence into a sparse feature vector (no bias feature added).
     *
     * @param sequenceExample The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param <T> The output type.
     * @return One sparse vector per position.
     * @throws IllegalArgumentException If the sequence is empty or any position has no known features.
     */
    public static <T extends Output<T>> SparseVector[] convert(SequenceExample<T> sequenceExample, ImmutableFeatureMap featureIDMap) {
        requireNonEmpty(sequenceExample);
        SparseVector[] features = new SparseVector[sequenceExample.size()];
        int i = 0;
        for (Example<T> example : sequenceExample) {
            features[i++] = toFeatureVector(example, featureIDMap);
        }
        return features;
    }

    /**
     * Converts each position of a labelled sequence into its label id and sparse feature vector
     * (no bias feature added).
     *
     * @param sequenceExample The sequence to convert.
     * @param featureIDMap The feature domain.
     * @param labelIDMap The label domain.
     * @return The per-position label ids paired with the per-position feature vectors.
     * @throws IllegalArgumentException If the sequence is empty or any position has no known features.
     */
    public static Pair<int[], SparseVector[]> convert(SequenceExample<Label> sequenceExample, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        requireNonEmpty(sequenceExample);
        int size = sequenceExample.size();
        int[] labels = new int[size];
        SparseVector[] features = new SparseVector[size];
        int i = 0;
        for (Example<Label> example : sequenceExample) {
            labels[i] = labelIDMap.getID(example.getOutput());
            features[i] = toFeatureVector(example, featureIDMap);
            i++;
        }
        return new Pair<>(labels, features);
    }

    /** Rejects a sequence with no positions. */
    private static void requireNonEmpty(SequenceExample<?> sequenceExample) {
        if (sequenceExample.size() == 0) {
            throw new IllegalArgumentException("SequenceExample is empty, " + sequenceExample.toString());
        }
    }

    /** Builds the sparse feature vector for one position, rejecting positions with no known features. */
    private static <T extends Output<T>> SparseVector toFeatureVector(Example<T> example, ImmutableFeatureMap featureIDMap) {
        SparseVector vector = SparseVector.createSparseVector(example, featureIDMap, false);
        if (vector.numActiveElements() == 0) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }
        return vector;
    }
}
