package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.ArrayList;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link TransducerTrainer} that fits a {@link CRF} by maximizing a {@link CRFOptimizableByGE}
 * objective built from a list of {@link GEConstraint}s (GE: generalized expectation).
 *
 * <p>Unless another optimizer is supplied, optimization uses {@link LimitedMemoryBFGS}. Training
 * is split into {@code numResets + 1} rounds that share one iteration budget; an L-BFGS optimizer
 * is reset between rounds so it can continue after reporting convergence.
 */
public class CRFTrainerByGE extends TransducerTrainer {
    private static final Logger logger = MalletLogger.getLogger(CRFTrainerByGE.class.getName());

    /** Default number of extra optimizer reset rounds performed by {@link #train}. */
    private static final int DEFAULT_NUM_RESETS = 1;
    /** Default Gaussian prior variance handed to the GE objective. */
    private static final int DEFAULT_GPV = 10;

    private boolean converged;
    /** Total number of successful optimizer steps taken across all calls to {@link #train}. */
    private int iteration;
    private final int numThreads;
    private int numResets;
    private double gaussianPriorVariance;
    private final ArrayList<GEConstraint> constraints;
    private final CRF crf;
    private StateLabelMap stateLabelMap;
    /** Lazily created by {@link #getOptimizable}; reused on later calls. */
    private CRFOptimizableByGE optimizable;
    /** Lazily created by {@link #getOptimizer} unless supplied via a setter. */
    private Optimizer optimizer;

    /**
     * Creates a GE trainer whose objective is built with a thread count of 1.
     *
     * @param crf the CRF whose parameters are trained
     * @param arrayList the GE constraints; must be non-empty when {@link #train} is called
     */
    public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> arrayList) {
        this(crf, arrayList, 1);
    }

    /**
     * Creates a GE trainer.
     *
     * @param crf the CRF whose parameters are trained
     * @param arrayList the GE constraints; must be non-empty when {@link #train} is called
     * @param numThreads thread count forwarded to the {@link CRFOptimizableByGE} constructor
     */
    public CRFTrainerByGE(CRF crf, ArrayList<GEConstraint> arrayList, int numThreads) {
        this.converged = false;
        this.iteration = 0;
        this.constraints = arrayList;
        this.crf = crf;
        this.numThreads = numThreads;
        this.numResets = DEFAULT_NUM_RESETS;
        this.gaussianPriorVariance = DEFAULT_GPV;
        // NOTE(review): meaning of the boolean flag is defined by StateLabelMap — confirm there.
        this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iteration;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.crf;
    }

    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Sets the Gaussian prior variance. Only takes effect if called before the GE objective is
     * first created by {@link #getOptimizable} (or {@link #train}).
     */
    public void setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
    }

    /** Sets how many extra reset rounds {@link #train} performs after the first one. */
    public void setNumResets(int i) {
        this.numResets = i;
    }

    /**
     * Sets the state/label map. Only takes effect if called before the GE objective is first
     * created by {@link #getOptimizable} (or {@link #train}).
     */
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.stateLabelMap = stateLabelMap;
    }

    /**
     * Sets the optimizer used by {@link #train}. Despite its name this sets the optimizer, not the
     * objective; prefer {@link #setOptimizer(Optimizer)}. Kept for backward compatibility.
     */
    public void setOptimizable(Optimizer optimizer) {
        setOptimizer(optimizer);
    }

    /**
     * Sets the optimizer used by {@link #train}, replacing the default {@link LimitedMemoryBFGS}.
     * The caller is responsible for binding it to the objective returned by {@link #getOptimizable}.
     */
    public void setOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * Returns the GE objective, creating it on the first call from {@code instanceList} and the
     * current constraints, state/label map, thread count, and prior variance. Later calls return
     * the cached objective and ignore their argument.
     */
    public Optimizable.ByGradientValue getOptimizable(InstanceList instanceList) {
        if (this.optimizable == null) {
            this.optimizable = new CRFOptimizableByGE(
                    this.crf, this.constraints, instanceList, this.stateLabelMap, this.numThreads);
            this.optimizable.setGaussianPriorVariance(this.gaussianPriorVariance);
        }
        return this.optimizable;
    }

    /**
     * Returns the optimizer, creating a {@link LimitedMemoryBFGS} over {@code byGradientValue} if
     * none has been created or set yet. Later calls ignore their argument.
     */
    public Optimizer getOptimizer(Optimizable.ByGradientValue byGradientValue) {
        if (this.optimizer == null) {
            this.optimizer = new LimitedMemoryBFGS(byGradientValue);
        }
        return this.optimizer;
    }

    /**
     * Runs GE training for at most {@code numIterations} optimizer steps in total.
     *
     * <p>The step budget is shared by {@code numResets + 1} rounds. Within a round the optimizer
     * takes one step at a time until it reports convergence, throws (treated as convergence), or
     * the budget is exhausted. After each round an L-BFGS optimizer is reset. The objective is
     * shut down when training ends, even if an unexpected exception escapes.
     *
     * @param instanceList instances used to build the GE objective (only on its first creation)
     * @param numIterations maximum total number of optimizer steps
     * @return {@code true} if the final step reported (or was treated as) convergence
     * @throws RuntimeException if the constraint list is empty
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int numIterations) {
        assert this.constraints.size() > 0;
        if (this.constraints.isEmpty()) {
            throw new RuntimeException("No constraints specified!");
        }
        getOptimizable(instanceList);
        // NOTE(review): the cached objective is shut down below but reused by any later call to
        // train() — confirm CRFOptimizableByGE tolerates use after shutdown().
        try {
            getOptimizer(this.optimizable);
            resetIfLimitedMemoryBFGS();
            this.converged = false;
            logger.info("CRF about to train with " + numIterations + " iterations");
            // The step counter deliberately persists across rounds: later rounds only consume
            // whatever budget earlier rounds left unused.
            int step = 0;
            for (int round = 0; round < this.numResets + 1; round++) {
                for (; step < numIterations; step++) {
                    if (optimizeOneStep(step)) {
                        logger.info("CRF training has converged, i=" + step);
                        break; // leave 'step' un-incremented, as the next round retries it
                    }
                }
                resetIfLimitedMemoryBFGS();
            }
        } finally {
            this.optimizable.shutdown();
        }
        return this.converged;
    }

    /**
     * Takes one optimizer step and runs the evaluators. Any exception ends training by marking it
     * converged.
     *
     * @param step index of this step within the current train() call, used only for logging
     * @return the updated {@code converged} flag
     */
    private boolean optimizeOneStep(int step) {
        try {
            this.converged = this.optimizer.optimize(1);
            this.iteration++;
            logger.info("CRF finished one iteration of maximizer, i=" + step);
            runEvaluators();
        } catch (Exception e) {
            // Covers IllegalArgumentException as well; both were handled identically before.
            logger.log(Level.WARNING, "Catching exception; saying converged.", e);
            this.converged = true;
        }
        return this.converged;
    }

    /** Clears L-BFGS history so the next round starts fresh; no-op for other optimizers. */
    private void resetIfLimitedMemoryBFGS() {
        if (this.optimizer instanceof LimitedMemoryBFGS) {
            ((LimitedMemoryBFGS) this.optimizer).reset();
        }
    }
}
