package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tribuo.classification.sgd.crf.ChainHelper;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;

/**
 * The parameters of a linear-chain CRF, stored as three tensors:
 * <ol>
 *     <li>a per-label bias vector ({@code numLabels} elements),</li>
 *     <li>a label-by-feature weight matrix ({@code numLabels x numFeatures}),</li>
 *     <li>a label-by-label transition matrix ({@code numLabels x numLabels}), indexed as
 *     {@code [previousLabel][currentLabel]}.</li>
 * </ol>
 * {@link #get()} exposes these tensors in exactly that order.
 * <p>
 * NOTE(review): field names are deliberately left unchanged because this class is
 * {@link Serializable} and Java serialization matches fields by name.
 */
public class CRFParameters implements Parameters, Serializable {
    private static final long serialVersionUID = 1L;

    private final int numLabels;
    private final int numFeatures;

    /** Shared merger used to combine the sparse per-token / per-example feature gradients. */
    private static final Merger merger = new HeapMerger();

    /** Backing array: [0] aliases {@link #biases}, [1] {@link #featureLabelWeights}, [2] {@link #labelLabelWeights}. */
    private Tensor[] weights = new Tensor[3];
    private DenseVector biases;
    private DenseMatrix featureLabelWeights;
    private DenseMatrix labelLabelWeights;

    /**
     * Allocates fresh parameter tensors for the given feature and label space sizes
     * (presumably zero-filled by the DenseVector/DenseMatrix constructors — confirm).
     * <p>
     * The decompiler reports this constructor was originally package-private; it is kept
     * public so existing callers keep compiling.
     *
     * @param numFeatures The number of features.
     * @param numLabels   The number of labels.
     */
    public CRFParameters(int numFeatures, int numLabels) {
        this.biases = new DenseVector(numLabels);
        this.featureLabelWeights = new DenseMatrix(numLabels, numFeatures);
        this.labelLabelWeights = new DenseMatrix(numLabels, numLabels);
        this.weights[0] = this.biases;
        this.weights[1] = this.featureLabelWeights;
        this.weights[2] = this.labelLabelWeights;
        this.numLabels = numLabels;
        this.numFeatures = numFeatures;
    }

    /**
     * Returns the weights of one feature across all labels (a column of the
     * label-by-feature matrix).
     *
     * @param featureID The feature index.
     * @return The column of feature/label weights for that feature.
     */
    public DenseVector getFeatureWeights(int featureID) {
        return featureLabelWeights.getColumn(featureID);
    }

    /**
     * Returns the bias for a label.
     *
     * @param labelID The label index.
     * @return The bias.
     */
    public double getBias(int labelID) {
        return biases.get(labelID);
    }

    /**
     * Returns a single feature/label weight.
     *
     * @param labelID   The label index (row).
     * @param featureID The feature index (column).
     * @return The weight.
     */
    public double getWeight(int labelID, int featureID) {
        return featureLabelWeights.get(labelID, featureID);
    }

    /**
     * Computes the per-token label scores: the feature/label weights applied to each
     * token's features via {@code leftMultiply}, plus the label biases.
     *
     * @param features One sparse feature vector per token.
     * @return One score vector (length {@code numLabels}) per token.
     */
    public DenseVector[] getLocalScores(SparseVector[] features) {
        DenseVector[] localScores = new DenseVector[features.length];
        for (int t = 0; t < features.length; t++) {
            DenseVector scores = featureLabelWeights.leftMultiply(features[t]);
            scores.intersectAndAddInPlace(biases);
            localScores[t] = scores;
        }
        return localScores;
    }

    /**
     * Bundles the per-token scores with the transition matrix for the chain inference routines.
     *
     * @param features One sparse feature vector per token.
     * @return The clique values for {@link ChainHelper}.
     */
    public ChainHelper.ChainCliqueValues getCliqueValues(SparseVector[] features) {
        return new ChainHelper.ChainCliqueValues(getLocalScores(features), labelLabelWeights);
    }

    /**
     * Decodes a label sequence using {@link ChainHelper#viterbi}.
     *
     * @param features One sparse feature vector per token.
     * @return The decoded label index for each token.
     */
    public int[] predict(SparseVector[] features) {
        return ChainHelper.viterbi(getCliqueValues(features)).mapValues;
    }

    /**
     * Computes per-token label marginals from belief propagation.
     * <p>
     * For each token, alpha + beta is normalised in place with {@code expNormalize(logZ)}.
     * Presumably this computes {@code exp(x - logZ)} elementwise; confirm.
     *
     * @param features One sparse feature vector per token.
     * @return One marginal vector per token.
     */
    public DenseVector[] predictMarginals(SparseVector[] features) {
        ChainHelper.ChainBPResults bpResults = ChainHelper.beliefPropagation(getCliqueValues(features));
        DenseVector[] marginals = new DenseVector[features.length];
        for (int t = 0; t < features.length; t++) {
            marginals[t] = bpResults.alphas[t].add(bpResults.betas[t]);
            marginals[t].expNormalize(bpResults.logZ);
        }
        return marginals;
    }

    /**
     * Scores each chunk with constrained belief propagation.
     * <p>
     * Each returned value is {@code exp(constrainedLogZ - logZ)}, where the constrained value
     * comes from {@link ChainHelper#constrainedBeliefPropagation}. The constraints are a
     * per-token array that starts as all {@code -1} and is then filled in by
     * {@code Chunk.unpack}.
     *
     * @param features One sparse feature vector per token.
     * @param chunks   The chunks to score.
     * @return One confidence per chunk, in the same order as {@code chunks}.
     */
    public List<Double> predictConfidenceUsingCBP(SparseVector[] features, List<Chunk> chunks) {
        ChainHelper.ChainCliqueValues cliqueValues = getCliqueValues(features);
        double logZ = ChainHelper.beliefPropagation(cliqueValues).logZ;
        // Presumably -1 means "unconstrained"; unpack overwrites the positions the chunk covers.
        int[] constraints = new int[features.length];
        List<Double> confidences = new ArrayList<>(chunks.size());
        for (Chunk chunk : chunks) {
            Arrays.fill(constraints, -1);
            chunk.unpack(constraints);
            double constrainedLogZ = ChainHelper.constrainedBeliefPropagation(cliqueValues, constraints);
            confidences.add(Math.exp(constrainedLogZ - logZ));
        }
        return confidences;
    }

    /**
     * Computes the score of the gold label sequence and its gradient.
     * <p>
     * The score is {@code -logZ} plus:
     * <ul>
     *     <li>the local score of each gold label, and</li>
     *     <li>the transition weight of each consecutive gold label pair.</li>
     * </ul>
     * This is the sequence log-probability, assuming {@code logZ} is the log partition function.
     * <p>
     * Gradient components:
     * <ul>
     *     <li>Biases: for each token, the gold-label indicator minus the token marginal.</li>
     *     <li>Feature weights: the outer product of that difference with the token's features,
     *     merged across tokens.</li>
     *     <li>Transitions: observed gold-pair counts minus the exponentiated
     *     {@code alpha[t-1][i] + T[i][j] + beta[t][j] + local[t][j] - logZ} terms.</li>
     * </ul>
     *
     * @param features One sparse feature vector per token.
     * @param labels   The gold label index per token; must have the same length as {@code features}.
     * @return A pair of (score, gradient) where the gradient is ordered like {@link #get()}.
     * @throws IllegalArgumentException If {@code labels} and {@code features} differ in length.
     */
    public Pair<Double, Tensor[]> valueAndGradient(SparseVector[] features, int[] labels) {
        if (labels.length != features.length) {
            throw new IllegalArgumentException("features and labels must have the same length, found "
                    + features.length + " features and " + labels.length + " labels");
        }
        ChainHelper.ChainCliqueValues scores = getCliqueValues(features);
        ChainHelper.ChainBPResults bpResults = ChainHelper.beliefPropagation(scores);
        double logZ = bpResults.logZ;
        DenseVector[] alphas = bpResults.alphas;
        SGDVector[] betas = bpResults.betas;

        Tensor[] gradient = new Tensor[3];
        DenseSparseMatrix[] featureGradients = new DenseSparseMatrix[features.length];
        DenseVector biasGradient = new DenseVector(biases.size());
        gradient[0] = biasGradient;
        DenseMatrix transGradient = new DenseMatrix(numLabels, numLabels);
        gradient[2] = transGradient;

        double score = -logZ;
        for (int t = 0; t < features.length; t++) {
            int curLabel = labels[t];
            DenseVector curLocalScores = scores.localValues[t];
            score += curLocalScores.get(curLabel);

            SGDVector curBeta = betas[t];
            // Turn alpha + beta into the (negated) local marginal, then add the gold indicator.
            DenseVector localGradient = alphas[t].add(curBeta);
            localGradient.expNormalize(logZ);
            localGradient.scaleInPlace(-1.0);
            localGradient.add(curLabel, 1.0);
            biasGradient.intersectAndAddInPlace(localGradient);
            featureGradients[t] = (DenseSparseMatrix) localGradient.outer(features[t]);

            // Transitions only exist from the second token onwards.
            if (t >= 1) {
                DenseVector prevAlpha = alphas[t - 1];
                for (int prev = 0; prev < numLabels; prev++) {
                    double prevAlphaValue = prevAlpha.get(prev);
                    for (int cur = 0; cur < numLabels; cur++) {
                        double expected = Math.exp(prevAlphaValue + labelLabelWeights.get(prev, cur)
                                + curBeta.get(cur) + curLocalScores.get(cur) - logZ);
                        transGradient.add(prev, cur, -expected);
                    }
                }
                int prevLabel = labels[t - 1];
                score += labelLabelWeights.get(prevLabel, curLabel);
                transGradient.add(prevLabel, curLabel, 1.0);
            }
        }
        gradient[1] = merger.merge(featureGradients);
        return new Pair<>(score, gradient);
    }

    /**
     * Allocates new tensors with the same shapes as the parameters.
     *
     * @return New bias, feature/label and label/label tensors, in {@link #get()} order.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        return new Tensor[]{
                new DenseVector(biases.size()),
                new DenseMatrix(featureLabelWeights.getDimension1Size(), featureLabelWeights.getDimension2Size()),
                new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size())};
    }

    /**
     * Returns the live parameter array (not a copy); mutating it mutates this object.
     *
     * @return {biases, featureLabelWeights, labelLabelWeights}.
     */
    @Override
    public Tensor[] get() {
        return weights;
    }

    /**
     * Replaces the parameters with {@code newWeights}, which is stored directly (not copied).
     * <p>
     * An array whose length differs from the current one is silently ignored, as before.
     * The element types are checked before any field is changed. A wrong element type
     * therefore throws {@link ClassCastException} and leaves this object unchanged.
     *
     * @param newWeights {biases (DenseVector), featureLabelWeights (DenseMatrix), labelLabelWeights (DenseMatrix)}.
     */
    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length == weights.length) {
            DenseVector newBiases = (DenseVector) newWeights[0];
            DenseMatrix newFeatureLabelWeights = (DenseMatrix) newWeights[1];
            DenseMatrix newLabelLabelWeights = (DenseMatrix) newWeights[2];
            weights = newWeights;
            biases = newBiases;
            featureLabelWeights = newFeatureLabelWeights;
            labelLabelWeights = newLabelLabelWeights;
        }
    }

    /**
     * Adds each supplied gradient tensor into the matching parameter tensor in place.
     *
     * @param gradients Gradient tensors, in {@link #get()} order.
     */
    @Override
    public void update(Tensor[] gradients) {
        for (int k = 0; k < gradients.length; k++) {
            weights[k].intersectAndAddInPlace(gradients[k]);
        }
    }

    /**
     * Sums the first {@code size} gradient arrays into one.
     * <p>
     * Bias and transition gradients are accumulated densely. The sparse feature gradients
     * (which must be {@link DenseSparseMatrix} instances) are combined with the shared merger.
     *
     * @param gradients The gradients to merge, each in {@link #get()} order.
     * @param size      How many entries of {@code gradients} to merge.
     * @return The merged gradient, in {@link #get()} order.
     */
    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        DenseVector biasUpdate = new DenseVector(biases.size());
        DenseSparseMatrix[] featureUpdates = new DenseSparseMatrix[size];
        DenseMatrix labelLabelUpdate = new DenseMatrix(labelLabelWeights.getDimension1Size(), labelLabelWeights.getDimension2Size());
        for (int j = 0; j < featureUpdates.length; j++) {
            biasUpdate.intersectAndAddInPlace(gradients[j][0]);
            featureUpdates[j] = (DenseSparseMatrix) gradients[j][1];
            labelLabelUpdate.intersectAndAddInPlace(gradients[j][2]);
        }
        return new Tensor[]{biasUpdate, merger.merge(featureUpdates), labelLabelUpdate};
    }
}
