package cc.mallet.classify;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/MaxEntOptimizableByLabelLikelihood.class */
/**
 * Optimizable objective for training a {@link MaxEnt} classifier by maximizing the
 * instance-weighted log-likelihood of each training instance's best label, minus an optional
 * penalty from a Gaussian or hyperbolic prior on the parameters.
 *
 * <p>Parameters are laid out as a row-major {@code numLabels x numFeatures} matrix. The last
 * column ({@link #defaultFeatureIndex}) holds an implicit default feature whose value is always 1.
 *
 * <p>The value and gradient are cached; any change to the parameters or to the prior
 * configuration marks both caches stale.
 */
public class MaxEntOptimizableByLabelLikelihood implements Optimizable.ByGradientValue {
    private static final Logger logger =
            MalletLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName());
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(
                    MaxEntOptimizableByLabelLikelihood.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0d;
    // Kept as a raw Class for source compatibility with existing readers of this field.
    static final Class DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    /** Row-major numLabels x numFeatures weights; shared with {@link #theClassifier}. */
    double[] parameters;
    /** Instance-weighted empirical feature counts for each instance's best label. */
    double[] constraints;
    double[] cachedGradient;
    MaxEnt theClassifier;
    InstanceList trainingList;
    double cachedValue;
    boolean cachedValueStale;
    boolean cachedGradientStale;
    int numLabels;
    int numFeatures;
    int defaultFeatureIndex;
    FeatureSelection featureSelection;
    FeatureSelection[] perLabelFeatureSelection;
    boolean usingHyperbolicPrior = false;
    boolean usingGaussianPrior = true;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    Class maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;

    /** Creates an uninitialized objective; fields must be populated by the caller. */
    public MaxEntOptimizableByLabelLikelihood() {
    }

    /**
     * Builds the objective over {@code instanceList}.
     *
     * @param instanceList training data; its target alphabet is frozen by this constructor
     * @param maxEnt an existing classifier whose parameters and feature selections are reused
     *     (and updated in place), or {@code null} to create a fresh classifier
     */
    public MaxEntOptimizableByLabelLikelihood(InstanceList instanceList, MaxEnt maxEnt) {
        this.trainingList = instanceList;
        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        LabelAlphabet labelAlphabet = (LabelAlphabet) instanceList.getTargetAlphabet();
        labelAlphabet.stopGrowth();
        this.numLabels = labelAlphabet.size();
        // One extra column per label for the always-on default feature.
        this.numFeatures = dataAlphabet.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;
        // New Java arrays are zero-initialized, so no explicit fill is needed.
        this.parameters = new double[this.numLabels * this.numFeatures];
        this.constraints = new double[this.numLabels * this.numFeatures];
        this.cachedGradient = new double[this.numLabels * this.numFeatures];

        this.featureSelection = instanceList.getFeatureSelection();
        this.perLabelFeatureSelection = instanceList.getPerLabelFeatureSelection();
        // The default feature must always survive feature selection.
        if (this.featureSelection != null) {
            this.featureSelection.add(this.defaultFeatureIndex);
        }
        if (this.perLabelFeatureSelection != null) {
            for (FeatureSelection selection : this.perLabelFeatureSelection) {
                selection.add(this.defaultFeatureIndex);
            }
        }
        assert this.featureSelection == null || this.perLabelFeatureSelection == null;

        if (maxEnt != null) {
            // Train the supplied classifier in place: share its parameter array and selections.
            this.theClassifier = maxEnt;
            this.parameters = this.theClassifier.parameters;
            this.featureSelection = this.theClassifier.featureSelection;
            this.perLabelFeatureSelection = this.theClassifier.perClassFeatureSelection;
            this.defaultFeatureIndex = this.theClassifier.defaultFeatureIndex;
            assert maxEnt.getInstancePipe() == instanceList.getPipe();
        } else {
            this.theClassifier = new MaxEnt(instanceList.getPipe(), this.parameters,
                    this.featureSelection, this.perLabelFeatureSelection);
        }
        this.cachedValueStale = true;
        this.cachedGradientStale = true;

        logger.fine("Number of instances in training list = " + this.trainingList.size());
        gatherConstraints(dataAlphabet);
    }

    /**
     * Accumulates the instance-weighted feature counts of every labeled instance into the row of
     * its best label, including the default feature. NaN feature values are logged.
     */
    private void gatherConstraints(Alphabet dataAlphabet) {
        Iterator<Instance> it = this.trainingList.iterator();
        while (it.hasNext()) {
            Instance instance = it.next();
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            FeatureVector featureVector = (FeatureVector) instance.getData();
            Alphabet alphabet = featureVector.getAlphabet();
            assert alphabet == dataAlphabet;

            int bestIndex = labeling.getBestIndex();
            MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, bestIndex, featureVector,
                    instanceWeight);
            assert !Double.isNaN(instanceWeight) : "instanceWeight is NaN";

            boolean hasNaN = false;
            for (int loc = 0; loc < featureVector.numLocations(); loc++) {
                if (Double.isNaN(featureVector.valueAtLocation(loc))) {
                    logger.info("NaN for feature "
                            + alphabet.lookupObject(featureVector.indexAtLocation(loc)).toString());
                    hasNaN = true;
                }
            }
            if (hasNaN) {
                logger.info("NaN in instance: " + instance.getName());
            }
            // The default feature has value 1, so it contributes exactly the instance weight.
            this.constraints[bestIndex * this.numFeatures + this.defaultFeatureIndex] += instanceWeight;
        }
    }

    /** Marks both the cached value and the cached gradient as needing recomputation. */
    private void invalidateCaches() {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
    }

    public MaxEnt getClassifier() {
        return this.theClassifier;
    }

    @Override
    public double getParameter(int i) {
        return this.parameters[i];
    }

    @Override
    public void setParameter(int i, double d) {
        invalidateCaches();
        this.parameters[i] = d;
    }

    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buffer}.
     *
     * @throws IllegalArgumentException if {@code buffer} is null or its length differs from
     *     {@link #getNumParameters()}. Previously such a buffer was silently replaced by a local
     *     copy the caller could never see.
     */
    @Override
    public void getParameters(double[] buffer) {
        if (buffer == null || buffer.length != this.parameters.length) {
            throw new IllegalArgumentException("parameter buffer must have length "
                    + this.parameters.length + " but was "
                    + (buffer == null ? "null" : String.valueOf(buffer.length)));
        }
        System.arraycopy(this.parameters, 0, buffer, 0, this.parameters.length);
    }

    /**
     * Overwrites the parameters with {@code dArr}.
     *
     * <p>NOTE(review): when the length differs, a new array is allocated and is no longer shared
     * with {@link #theClassifier}, whose scores would then use the old array. Confirm whether
     * callers ever resize.
     */
    @Override
    public void setParameters(double[] dArr) {
        assert dArr != null;
        invalidateCaches();
        if (dArr.length != this.parameters.length) {
            this.parameters = new double[dArr.length];
        }
        System.arraycopy(dArr, 0, this.parameters, 0, dArr.length);
    }

    /**
     * Returns the weighted log-likelihood of each labeled instance's best label minus the prior
     * penalty. As a side effect, it also accumulates the negated model expectations into
     * {@link #cachedGradient}, which {@link #getValueGradient} completes.
     *
     * <p>If any instance's term is infinite, a warning is logged and the current partial sum is
     * returned immediately. The prior is not added and the remaining instances are not processed.
     */
    @Override
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        this.numGetValueCalls++;
        this.cachedValue = 0.0d;
        // The gradient's expectation term is rebuilt below, so any cached gradient is invalid.
        this.cachedGradientStale = true;
        MatrixOps.setAll(this.cachedGradient, 0.0d);

        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        Iterator<Instance> it = this.trainingList.iterator();
        while (it.hasNext()) {
            Instance instance = it.next();
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                continue;
            }
            this.theClassifier.getClassificationScores(instance, scores);
            FeatureVector featureVector = (FeatureVector) instance.getData();
            int bestIndex = labeling.getBestIndex();
            double term = -(instanceWeight * Math.log(scores[bestIndex]));
            if (Double.isNaN(term)) {
                logger.fine("MaxEntTrainer: Instance " + instance.getName()
                        + " has NaN value. log(scores)= " + Math.log(scores[bestIndex])
                        + " scores = " + scores[bestIndex]
                        + " has instance weight = " + instanceWeight);
            }
            if (Double.isInfinite(term)) {
                logger.warning("Instance " + instance.getSource()
                        + " has infinite value; skipping value and gradient");
                this.cachedValue -= term;
                this.cachedValueStale = false;
                return -term;
            }
            this.cachedValue += term;

            // Subtract this instance's expected feature counts under the model.
            for (int label = 0; label < scores.length; label++) {
                if (scores[label] == 0.0d) {
                    continue;
                }
                assert !Double.isInfinite(scores[label]);
                double factor = -instanceWeight * scores[label];
                MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, label,
                        featureVector, factor);
                this.cachedGradient[this.numFeatures * label + this.defaultFeatureIndex] += factor;
            }
        }

        double prior = 0.0d;
        if (this.usingHyperbolicPrior) {
            for (int label = 0; label < this.numLabels; label++) {
                for (int feature = 0; feature < this.numFeatures; feature++) {
                    prior += (this.hyperbolicPriorSlope / this.hyperbolicPriorSharpness)
                            * Math.log(Maths.cosh(this.hyperbolicPriorSharpness
                                    * this.parameters[label * this.numFeatures + feature]));
                }
            }
        } else if (this.usingGaussianPrior) {
            for (int label = 0; label < this.numLabels; label++) {
                for (int feature = 0; feature < this.numFeatures; feature++) {
                    double weight = this.parameters[label * this.numFeatures + feature];
                    prior += (weight * weight) / (2.0d * this.gaussianPriorVariance);
                }
            }
        }

        double labelProb = this.cachedValue;
        // cachedValue accumulated the negative log-likelihood; flip the sign for maximization.
        this.cachedValue = -(labelProb + prior);
        this.cachedValueStale = false;
        progressLogger.info("Value (labelProb=" + labelProb + " prior=" + prior
                + ") loglikelihood = " + this.cachedValue);
        return this.cachedValue;
    }

    /**
     * Writes into {@code dArr} the gradient of {@link #getValue()}: the empirical constraints,
     * minus the model expectations, minus the Gaussian prior term when enabled. Negative-infinity
     * entries are replaced with 0. Entries for features outside the active feature selection are
     * zeroed.
     *
     * @throws UnsupportedOperationException if the hyperbolic prior is enabled. The check happens
     *     before any cached state is modified, so the object remains usable after switching priors.
     */
    @Override
    public void getValueGradient(double[] dArr) {
        if (this.cachedGradientStale) {
            if (this.usingHyperbolicPrior) {
                throw new UnsupportedOperationException("Hyperbolic prior not yet implemented.");
            }
            this.numGetValueGradientCalls++;
            if (this.cachedValueStale) {
                // getValue() fills cachedGradient with the negated expectations.
                getValue();
            }
            MatrixOps.plusEquals(this.cachedGradient, this.constraints);
            if (this.usingGaussianPrior) {
                MatrixOps.plusEquals(this.cachedGradient, this.parameters,
                        -1.0d / this.gaussianPriorVariance);
            }
            MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0d);
            for (int label = 0; label < this.numLabels; label++) {
                FeatureSelection selection = this.perLabelFeatureSelection == null
                        ? this.featureSelection
                        : this.perLabelFeatureSelection[label];
                MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, label, 0.0d,
                        selection, false);
            }
            this.cachedGradientStale = false;
        }
        assert dArr != null && dArr.length == this.parameters.length;
        System.arraycopy(this.cachedGradient, 0, dArr, 0, this.cachedGradient.length);
    }

    public int getValueGradientCalls() {
        return this.numGetValueGradientCalls;
    }

    public int getValueCalls() {
        return this.numGetValueCalls;
    }

    /** Selects the Gaussian prior. Invalidates the cached value and gradient. */
    public MaxEntOptimizableByLabelLikelihood useGaussianPrior() {
        this.usingGaussianPrior = true;
        this.usingHyperbolicPrior = false;
        invalidateCaches();
        return this;
    }

    /** Selects the hyperbolic prior (value only; the gradient is unsupported). Invalidates caches. */
    public MaxEntOptimizableByLabelLikelihood useHyperbolicPrior() {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        invalidateCaches();
        return this;
    }

    /** Disables any prior. Invalidates the cached value and gradient. */
    public MaxEntOptimizableByLabelLikelihood useNoPrior() {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = false;
        invalidateCaches();
        return this;
    }

    /** Selects the Gaussian prior with variance {@code d}. Invalidates caches. */
    public MaxEntOptimizableByLabelLikelihood setGaussianPriorVariance(double d) {
        this.usingGaussianPrior = true;
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = d;
        invalidateCaches();
        return this;
    }

    /** Selects the hyperbolic prior with slope {@code d}. Invalidates caches. */
    public MaxEntOptimizableByLabelLikelihood setHyperbolicPriorSlope(double d) {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = d;
        invalidateCaches();
        return this;
    }

    /** Selects the hyperbolic prior with sharpness {@code d}. Invalidates caches. */
    public MaxEntOptimizableByLabelLikelihood setHyperbolicPriorSharpness(double d) {
        this.usingGaussianPrior = false;
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSharpness = d;
        invalidateCaches();
        return this;
    }
}
