package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.util.ArrayList;
import java.util.Collections;

/* loaded from: input_file:cc/mallet/fst/CRFTrainerByStochasticGradient.class */
/**
 * Trains a {@link CRF} by stochastic gradient steps on the per-instance log-likelihood, updating
 * the parameters after every instance.
 *
 * <p>During {@link #train(InstanceList, int, int)} the step size follows
 * {@code learningRate = 1 / (lambda * t)}, with {@code t} incremented after each instance update.
 */
public class CRFTrainerByStochasticGradient extends TransducerTrainer.ByInstanceIncrements {
    protected CRF crf;
    /** Current step size; {@link #train(InstanceList, int, int)} overwrites it with the decayed rate. */
    protected double learningRate;
    /** Step counter of the decay schedule {@code 1 / (lambda * t)}. */
    protected double t;
    /** Decay coefficient; {@link #train(InstanceList, int, int)} sets it to {@code 1 / trainingSet.size()}. */
    protected double lambda;
    /** Number of epochs run so far across all calls to {@link #train(InstanceList, int, int)}. */
    protected int iterationCount = 0;
    protected boolean converged = false;
    /** Scratch buffer filled from the unconstrained lattice of the current instance. */
    protected CRF.Factors expectations;
    /** Scratch buffer filled from the target-constrained lattice; reused to hold the gradient. */
    protected CRF.Factors constraints;

    /**
     * Creates a trainer and picks its learning rate with
     * {@link #setLearningRateByLikelihood(InstanceList)}.
     *
     * @param crf the model to train; its parameters are left zeroed after rate selection
     * @param instanceList sample used to probe candidate learning rates
     */
    public CRFTrainerByStochasticGradient(CRF crf, InstanceList instanceList) {
        this.crf = crf;
        this.expectations = new CRF.Factors(crf);
        this.constraints = new CRF.Factors(crf);
        setLearningRateByLikelihood(instanceList);
    }

    /**
     * Creates a trainer with an explicit initial learning rate.
     *
     * @param crf the model to train
     * @param learningRate initial step size
     */
    public CRFTrainerByStochasticGradient(CRF crf, double learningRate) {
        this.crf = crf;
        this.learningRate = learningRate;
        this.expectations = new CRF.Factors(crf);
        this.constraints = new CRF.Factors(crf);
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iterationCount;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.crf;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Chooses a learning rate by trying the candidates 1e-10, 2e-10, 4e-10, ... (doubling until a
     * candidate reaches at least 1). Each probe runs 5 epochs from zero parameters. The candidate
     * with the largest log-likelihood gain is halved and installed via {@link #setLearningRate}.
     * The CRF parameters are zeroed again afterwards.
     *
     * @param instanceList sample used for probing
     * @throws IllegalStateException if no candidate produced a comparable (non-NaN) gain
     */
    public void setLearningRateByLikelihood(InstanceList instanceList) {
        final int probeEpochs = 5;
        double bestRate = Double.NEGATIVE_INFINITY;
        double bestChange = Double.NEGATIVE_INFINITY;

        // Every probe starts from zero parameters, so the "before" likelihood is identical for all
        // candidates. It must be measured before any training touches the parameters; measuring
        // it after trainSample() would compare post-training against post-training.
        this.crf.parameters.zero();
        double baseline = computeLikelihood(instanceList);

        double rate = 5.0E-11d;
        while (rate < 1.0d) {
            rate *= 2.0d;
            this.crf.parameters.zero();
            double change = trainSample(instanceList, probeEpochs, rate) - baseline;
            System.out.println("likelihood change = " + change + " for learningrate=" + rate);
            if (change > bestChange) {
                bestChange = change;
                bestRate = rate;
            }
        }
        this.crf.parameters.zero();
        if (bestRate == Double.NEGATIVE_INFINITY) {
            // Every comparison failed (e.g. all changes were NaN); installing a rate of -Infinity
            // would silently corrupt all later training.
            throw new IllegalStateException(
                    "could not select a learning rate: no candidate produced a comparable likelihood change");
        }
        // Conservative choice: half of the best probed rate.
        double chosenRate = bestRate / 2.0d;
        System.out.println("Setting learning rate to " + chosenRate);
        setLearningRate(chosenRate);
    }

    /**
     * Runs {@code numEpochs} passes of per-instance updates over {@code instanceList}, in stored
     * order, starting from the CRF's current parameters. The step size starts at {@code rate} and
     * decays as {@code 1 / (size * t)}.
     *
     * @return the summed per-instance log-likelihood of the last pass, each term taken just before
     *     that instance's update; {@code -Infinity} if {@code numEpochs <= 0}
     */
    private double trainSample(InstanceList instanceList, int numEpochs, double rate) {
        // NOTE(review): here the decay coefficient is N (= size) while train() uses 1/N, so the
        // probe's step size decays N^2 times faster than during real training. Kept unchanged;
        // confirm whether the calibration schedule is meant to match train().
        double size = instanceList.size();
        double step = 1.0d / (size * rate);
        double loglik = Double.NEGATIVE_INFINITY;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            loglik = 0.0d;
            for (int j = 0; j < instanceList.size(); j++) {
                loglik += trainIncrementalLikelihood(instanceList.get(j), 1.0d / (size * step));
                step += 1.0d;
            }
        }
        return loglik;
    }

    /**
     * Sums, over all instances, the constrained-lattice total weight minus the unconstrained
     * lattice total weight under the current parameters. The parameters are not modified; the
     * scratch factor buffers are cleared afterwards.
     */
    private double computeLikelihood(InstanceList instanceList) {
        double total = 0.0d;
        for (int k = 0; k < instanceList.size(); k++) {
            Instance instance = instanceList.get(k);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            Sequence target = (Sequence) instance.getTarget();
            total += new SumLatticeDefault(this.crf, input, target, null).getTotalWeight();
            total -= new SumLatticeDefault(this.crf, input, null, null).getTotalWeight();
        }
        this.constraints.zero();
        this.expectations.zero();
        return total;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Returns the current step size (the decayed value once training has run). */
    public double getLearningRate() {
        return this.learningRate;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int numIterations) {
        return train(instanceList, numIterations, 1);
    }

    /**
     * Runs up to {@code numIterations} epochs, each over a freshly shuffled instance order.
     *
     * <p>The decay schedule is restarted from the current {@link #learningRate}. Because that field
     * holds the decayed rate after a previous call, calling again on a list of the same size
     * roughly continues the earlier schedule.
     *
     * @param instanceList training data
     * @param numIterations maximum number of epochs
     * @param numIterationsBetweenEvaluation run evaluators whenever the total epoch count is a
     *     multiple of this value; must be at least 1
     * @return true if two consecutive epoch log-likelihoods differed by less than 1e-3
     * @throws IllegalArgumentException if {@code numIterationsBetweenEvaluation < 1}
     */
    public boolean train(InstanceList instanceList, int numIterations, int numIterationsBetweenEvaluation) {
        assert this.expectations.structureMatches(this.crf.parameters);
        assert this.constraints.structureMatches(this.crf.parameters);
        if (numIterationsBetweenEvaluation < 1) {
            // Zero would otherwise fail later with an ArithmeticException in the modulo below.
            throw new IllegalArgumentException(
                    "numIterationsBetweenEvaluation must be >= 1, got " + numIterationsBetweenEvaluation);
        }
        this.lambda = 1.0d / instanceList.size();
        this.t = 1.0d / (this.lambda * this.learningRate);
        this.converged = false;

        ArrayList<Integer> order = new ArrayList<Integer>(instanceList.size());
        for (int k = 0; k < instanceList.size(); k++) {
            order.add(k);
        }

        double previousLoglik = Double.NEGATIVE_INFINITY;
        while (numIterations-- > 0) {
            this.iterationCount++;
            Collections.shuffle(order);
            double loglik = 0.0d;
            for (int index : order) {
                this.learningRate = 1.0d / (this.lambda * this.t);
                loglik += trainIncrementalLikelihood(instanceList.get(index));
                this.t += 1.0d;
            }
            // Prints the number of epochs remaining, not the epoch index.
            System.out.println("loglikelihood[" + numIterations + "] = " + loglik);
            if (Math.abs(loglik - previousLoglik) < 0.001d) {
                this.converged = true;
                break;
            }
            previousLoglik = loglik;
            Runtime.getRuntime().gc();
            if (this.iterationCount % numIterationsBetweenEvaluation == 0) {
                runEvaluators();
            }
        }
        return this.converged;
    }

    /** Runs a single epoch over {@code instanceList}; always returns false. */
    @Override // cc.mallet.fst.TransducerTrainer.ByIncrements
    public boolean trainIncremental(InstanceList instanceList) {
        train(instanceList, 1);
        return false;
    }

    /** Applies one update for {@code instance} at the current learning rate; always returns false. */
    @Override // cc.mallet.fst.TransducerTrainer.ByInstanceIncrements
    public boolean trainIncremental(Instance instance) {
        assert this.expectations.structureMatches(this.crf.parameters);
        trainIncrementalLikelihood(instance);
        return false;
    }

    /** Same as {@link #trainIncrementalLikelihood(Instance, double)} at the current learning rate. */
    public double trainIncrementalLikelihood(Instance instance) {
        return trainIncrementalLikelihood(instance, this.learningRate);
    }

    /**
     * Performs one gradient step for a single instance.
     *
     * @param instance instance whose data is a {@link FeatureVectorSequence} and whose target is a
     *     {@link Sequence}
     * @param rate step size applied to {@code constraints - expectations}
     * @return constrained-lattice total weight minus unconstrained-lattice total weight, computed
     *     before the parameters are updated
     */
    public double trainIncrementalLikelihood(Instance instance, double rate) {
        this.constraints.zero();
        this.expectations.zero();
        FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
        Sequence target = (Sequence) instance.getTarget();
        // Incrementor is an inner class of CRF.Factors: each one must be bound to the buffer it fills.
        double loglik = new SumLatticeDefault(
                this.crf, input, target, this.constraints.new Incrementor()).getTotalWeight();
        loglik -= new SumLatticeDefault(
                this.crf, input, null, this.expectations.new Incrementor()).getTotalWeight();
        // Gradient (constraints - expectations) is formed in place inside `constraints`.
        this.constraints.plusEquals(this.expectations, -1.0d);
        this.crf.parameters.plusEquals(this.constraints, rate, true);
        return loglik;
    }
}
