package cc.mallet.fst;

import cc.mallet.classify.MCMaxEntTrainer;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.ExpGain;
import cc.mallet.types.FeatureInducer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GradientGain;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/CRFTrainerByLabelLikelihood.class */
/**
 * Trains a {@link CRF} by repeatedly stepping an {@link Optimizer} over a
 * {@link CRFOptimizableByLabelLikelihood} built from the training instances.
 *
 * <p>The optimizable wrapper and the optimizer are created lazily and cached; they are rebuilt
 * when the CRF's weight structure stamp changes or when a different training set is supplied.
 */
public class CRFTrainerByLabelLikelihood extends TransducerTrainer implements TransducerTrainer.ByOptimization {
    /** Class-wide logger; assigned in the static initializer. */
    private static Logger logger;
    /** Initial value of {@link #gaussianPriorVariance}. */
    // NOTE(review): several training methods also use this constant where a plain 1.0 is meant
    // (proportion checks); presumably a decompiler constant substitution -- confirm.
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    /** Initial value of {@link #hyperbolicPriorSlope}. */
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2d;
    /** Initial value of {@link #hyperbolicPriorSharpness}. */
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0d;
    /** The model being trained; supplied to the constructor. */
    CRF crf;
    /** Cached optimizable view of {@link #crf}; created and replaced by {@code getOptimizableCRF}. */
    CRFOptimizableByLabelLikelihood ocrf;
    /** Cached optimizer over {@link #ocrf}; created by {@code getOptimizer(InstanceList)}. */
    Optimizer opt;
    /** Convergence flag reported by {@code isConverged()} and {@code isFinishedTraining()}. */
    boolean converged;
    private static final long serialVersionUID = 1;
    // Presumably the version number of the custom serialized form (writeObject emits 1 first).
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;
    // Assertion switch surfaced by decompilation; set in the static initializer and consulted
    // before each explicit "throw new AssertionError()".
    static final /* synthetic */ boolean $assertionsDisabled;
    /** Threshold of consecutive converged optimizer steps consulted by {@code trainOptimized}. */
    private int minConvRounds = 5;
    /** Incremented after every optimizer step that completes without throwing. */
    int iterationCount = 0;
    /** Forwarded to the optimizable via {@code setUseHyperbolicPrior} when it is (re)built. */
    boolean usingHyperbolicPrior = false;
    /** Forwarded to the optimizable when it is (re)built; also handed to {@code ExpGain.Factory}. */
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    /** Forwarded to the optimizable via {@code setHyperbolicPriorSlope} when it is (re)built. */
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    /** Forwarded to the optimizable via {@code setHyperbolicPriorSharpness} when it is (re)built. */
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    /** Selects {@code setWeightsDimensionAsIn} (true) or {@code setWeightsDimensionDensely} (false). */
    boolean useSparseWeights = true;
    /** When true, {@code getOptimizableCRF} does not set the CRF's weight dimensions at all. */
    boolean useNoWeights = false;
    /** Second argument passed to {@code crf.setWeightsDimensionAsIn}; not serialized (transient). */
    private transient boolean useSomeUnsupportedTrick = true;
    // The two stamps below start at -1 and, within this file, are only touched by the
    // serialization hooks.
    private int cachedValueWeightsStamp = -1;
    private int cachedGradientWeightsStamp = -1;
    /** Last {@code crf.weightsStructureChangeStamp} observed by {@code getOptimizableCRF}. */
    private int cachedWeightsStructureStamp = -1;
    /** Public flag; within this file it is only handled by the serialization hooks. */
    public boolean printGradient = false;

    /**
     * Sets the consecutive-convergence threshold consulted by {@code trainOptimized}.
     *
     * @param i the new threshold
     */
    public void setMinConvRounds(int i) {
        minConvRounds = i;
    }

    /**
     * Returns the consecutive-convergence threshold consulted by {@code trainOptimized}.
     *
     * @return the current threshold
     */
    public int getMinConvRounds() {
        return minConvRounds;
    }

    /**
     * Creates a trainer for the given model. The optimizable wrapper and the optimizer are not
     * built here; they are created lazily on first use.
     *
     * @param crf the CRF this trainer will update
     */
    public CRFTrainerByLabelLikelihood(CRF crf) {
        // "this." is required: the parameter shadows the field.
        this.crf = crf;
    }

    /**
     * Returns the model being trained, typed as a {@link Transducer}.
     *
     * @return the CRF given to the constructor
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return crf;
    }

    /**
     * Returns the model being trained with its concrete {@link CRF} type.
     *
     * @return the CRF given to the constructor
     */
    public CRF getCRF() {
        return crf;
    }

    /**
     * Returns the currently cached optimizer without creating one.
     *
     * @return the cached optimizer, or {@code null} if none has been built yet
     */
    @Override // cc.mallet.fst.TransducerTrainer.ByOptimization
    public Optimizer getOptimizer() {
        return opt;
    }

    /**
     * Reports the trainer's convergence flag.
     *
     * @return the current value of the {@code converged} field
     */
    public boolean isConverged() {
        return converged;
    }

    /**
     * Reports whether training is finished; for this trainer that is the convergence flag.
     *
     * @return the current value of the {@code converged} field
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return converged;
    }

    /**
     * Returns the number of optimizer steps completed so far across all training calls.
     *
     * @return the running iteration count
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return iterationCount;
    }

    /**
     * Sets the {@code useNoWeights} flag. While it is true, {@code getOptimizableCRF} leaves the
     * CRF's weight dimensions untouched.
     *
     * @param z the new flag value
     */
    public void setAddNoFactors(boolean z) {
        useNoWeights = z;
    }

    /**
     * Returns the optimizable wrapper for the given training set, rebuilding it when needed.
     *
     * <p>If the CRF's weight-structure stamp differs from the cached one, the weight dimensions
     * are set again (unless {@code useNoWeights} is true), the cached wrapper is dropped and the
     * stamp is re-read. A new wrapper is then built whenever none is cached or the cached one was
     * built for a different {@code InstanceList} object; building one also discards the cached
     * optimizer.
     *
     * @param instanceList the training instances
     * @return the cached or newly created optimizable
     */
    public CRFOptimizableByLabelLikelihood getOptimizableCRF(InstanceList instanceList) {
        boolean structureIsStale = this.crf.weightsStructureChangeStamp != this.cachedWeightsStructureStamp;
        if (structureIsStale) {
            boolean dimensionWeights = !this.useNoWeights;
            if (dimensionWeights && this.useSparseWeights) {
                this.crf.setWeightsDimensionAsIn(instanceList, this.useSomeUnsupportedTrick);
            } else if (dimensionWeights) {
                this.crf.setWeightsDimensionDensely();
            }
            this.ocrf = null;
            // Read the stamp only now: the calls above may have changed it.
            this.cachedWeightsStructureStamp = this.crf.weightsStructureChangeStamp;
        }

        boolean reusable = this.ocrf != null && this.ocrf.trainingSet == instanceList;
        if (reusable) {
            return this.ocrf;
        }

        CRFOptimizableByLabelLikelihood fresh = new CRFOptimizableByLabelLikelihood(this.crf, instanceList);
        this.ocrf = fresh;
        fresh.setGaussianPriorVariance(this.gaussianPriorVariance);
        fresh.setHyperbolicPriorSharpness(this.hyperbolicPriorSharpness);
        fresh.setHyperbolicPriorSlope(this.hyperbolicPriorSlope);
        fresh.setUseHyperbolicPrior(this.usingHyperbolicPrior);
        // An optimizer bound to the previous wrapper must not be reused.
        this.opt = null;
        return fresh;
    }

    /**
     * Returns an optimizer bound to the optimizable for the given training set, creating a
     * {@link LimitedMemoryBFGS} when none is cached or the cached one targets another optimizable.
     *
     * @param instanceList the training instances
     * @return the cached or newly created optimizer
     */
    public Optimizer getOptimizer(InstanceList instanceList) {
        // Called for its side effect of refreshing this.ocrf.
        getOptimizableCRF(instanceList);
        boolean stillValid = this.opt != null && this.opt.getOptimizable() == this.ocrf;
        if (!stillValid) {
            this.opt = new LimitedMemoryBFGS(this.ocrf);
        }
        return this.opt;
    }

    /**
     * Runs {@code train} on the given data with an iteration cap of {@code Integer.MAX_VALUE}.
     *
     * @param instanceList the training instances
     * @return whatever {@code train(instanceList, Integer.MAX_VALUE)} returns
     */
    public boolean trainIncremental(InstanceList instanceList) {
        final int iterationCap = Integer.MAX_VALUE;
        return train(instanceList, iterationCap);
    }

    /**
     * Runs {@code trainOptimized} on the given data with an iteration cap of
     * {@code Integer.MAX_VALUE}.
     *
     * @param instanceList the training instances
     * @return whatever {@code trainOptimized(instanceList, Integer.MAX_VALUE)} returns
     */
    public boolean trainOptimized(InstanceList instanceList) {
        final int iterationCap = Integer.MAX_VALUE;
        return trainOptimized(instanceList, iterationCap);
    }

    /**
     * Steps the optimizer one iteration at a time, for at most {@code i} iterations, and only
     * declares convergence after the optimizer has reported it on enough consecutive steps.
     *
     * <p>Training stops when a step reports convergence and the number of immediately preceding
     * consecutive converged steps has reached {@code minConvRounds}, or when the optimizer throws
     * an {@link OptimizationException} (treated as convergence). Evaluators run after every
     * step that completes.
     *
     * <p>The convergence flag is cleared at the start of every run that actually trains, so the
     * returned value and {@code isConverged()} describe this call and not an earlier one.
     *
     * @param instanceList the training instances; must be non-empty (checked when assertions are on)
     * @param i maximum number of optimizer steps; values {@code <= 0} return {@code false} at once
     *     without touching any state
     * @return {@code true} if this run ended in convergence (by the optimizer or by exception)
     */
    public boolean trainOptimized(InstanceList instanceList, int i) {
        if (i <= 0) {
            return false;
        }
        if (!$assertionsDisabled && instanceList.size() <= 0) {
            throw new AssertionError();
        }
        // Previously the field kept a "true" left over from an earlier run, so a run that hit
        // the iteration cap without converging still returned true.
        this.converged = false;
        getOptimizableCRF(instanceList);
        getOptimizer(instanceList);
        int successiveConvergences = 0;
        logger.info("CRF about to train with " + i + " iterations");
        for (int round = 0; round < i; round++) {
            // Fresh per step: a step that throws must not inherit the previous step's verdict.
            boolean stepConverged = false;
            boolean stepFailed = false;
            try {
                stepConverged = this.opt.optimize(1);
                this.iterationCount++;
                logger.info("CRF finished one iteration of maximizer, i=" + round);
                runEvaluators();
            } catch (OptimizationException e) {
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
                stepFailed = true;
            }
            if (!stepConverged) {
                successiveConvergences = 0;
            } else {
                if (successiveConvergences >= this.minConvRounds) {
                    logger.info("CRF training has converged by optimizer after " + successiveConvergences + " successive convergence rounds, i=" + round);
                    this.converged = true;
                    break;
                }
                logger.info("CRF optimizer converged, but need more successive convergence rounds, succConv=" + successiveConvergences);
                successiveConvergences++;
            }
            if (stepFailed) {
                logger.info("CRF training has converged by exception, i=" + round);
                this.converged = true;
                break;
            }
        }
        return this.converged;
    }

    /**
     * Runs {@code trainOptimized(InstanceList, int)} once per entry of {@code dArr}, each time on
     * the leading portion of the data given by that entry.
     *
     * <p>An entry equal to 1.0 trains on the full list; any other entry trains on the first part
     * of a two-way split made with a {@code Random} seeded with 1.
     *
     * @param instanceList the full training set
     * @param i iteration cap handed to every round
     * @param dArr per-round proportions; must be non-empty and each entry at most 1.0 (both
     *     checked only when assertions are enabled)
     * @return the result of the last round, or {@code false} if {@code dArr} is empty
     */
    public boolean trainOptimized(InstanceList instanceList, int i, double[] dArr) {
        if (!$assertionsDisabled && dArr.length <= 0) {
            throw new AssertionError();
        }
        boolean lastRoundResult = false;
        for (double proportion : dArr) {
            if (!$assertionsDisabled && proportion > 1.0d) {
                throw new AssertionError();
            }
            logger.info("Training on " + proportion + "% of the data this round.");
            if (proportion == 1.0d) {
                lastRoundResult = trainOptimized(instanceList, i);
            } else {
                double[] splitSizes = {proportion, 1.0d - proportion};
                InstanceList[] parts = instanceList.split(new Random(1L), splitSizes);
                lastRoundResult = trainOptimized(parts[0], i);
            }
        }
        return lastRoundResult;
    }

    /**
     * Steps the optimizer one iteration at a time until it reports convergence, an exception
     * occurs, or {@code i} iterations have been run.
     *
     * <p>Any exception thrown by the optimizer step or by the evaluators is printed and treated
     * as convergence (deliberate best-effort behaviour). Evaluators run after every step that
     * completes.
     *
     * <p>The outcome is stored in the {@code converged} field as well as returned, so
     * {@code isConverged()} and {@code isFinishedTraining()} reflect this call.
     *
     * @param instanceList the training instances; must be non-empty (checked when assertions are on)
     * @param i maximum number of optimizer steps; values {@code <= 0} return {@code false} at once
     *     without touching any state
     * @return {@code true} if training stopped because of convergence or an exception,
     *     {@code false} if the iteration cap was reached first
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int i) {
        if (i <= 0) {
            return false;
        }
        if (!$assertionsDisabled && instanceList.size() <= 0) {
            throw new AssertionError();
        }
        getOptimizableCRF(instanceList);
        getOptimizer(instanceList);
        boolean stepConverged = false;
        logger.info("CRF about to train with " + i + " iterations");
        for (int round = 0; round < i; round++) {
            try {
                stepConverged = this.opt.optimize(1);
                this.iterationCount++;
                logger.info("CRF finished one iteration of maximizer, i=" + round);
                runEvaluators();
            } catch (Exception e) {
                // One handler replaces two identical ones (IllegalArgumentException is an Exception).
                e.printStackTrace();
                logger.info("Catching exception; saying converged.");
                stepConverged = true;
            }
            if (stepConverged) {
                logger.info("CRF training has converged, i=" + round);
                break;
            }
        }
        // Previously only the local was updated, so the convergence accessors never saw the
        // result of train().
        this.converged = stepConverged;
        return stepConverged;
    }

    /**
     * Runs {@code train(InstanceList, int)} once per entry of {@code dArr}, each time on the
     * leading portion of the data given by that entry.
     *
     * <p>An entry equal to 1.0 trains on the full list; any other entry trains on the first part
     * of a two-way split made with a {@code Random} seeded with 1.
     *
     * @param instanceList the full training set
     * @param i iteration cap handed to every round
     * @param dArr per-round proportions; must be non-empty and each entry at most 1.0 (both
     *     checked only when assertions are enabled)
     * @return the result of the last round, or {@code false} if {@code dArr} is empty
     */
    public boolean train(InstanceList instanceList, int i, double[] dArr) {
        if (!$assertionsDisabled && dArr.length <= 0) {
            throw new AssertionError();
        }
        boolean lastRoundResult = false;
        for (double proportion : dArr) {
            if (!$assertionsDisabled && proportion > 1.0d) {
                throw new AssertionError();
            }
            logger.info("Training on " + proportion + "% of the data this round.");
            if (proportion == 1.0d) {
                lastRoundResult = train(instanceList, i);
            } else {
                double[] splitSizes = {proportion, 1.0d - proportion};
                InstanceList[] parts = instanceList.split(new Random(1L), splitSizes);
                lastRoundResult = train(parts[0], i);
            }
        }
        return lastRoundResult;
    }

    /**
     * Convenience overload of the twelve-argument {@code trainWithFeatureInduction} that uses
     * {@code MCMaxEntTrainer.EXP_GAIN} as the gain name; every other argument is passed through
     * unchanged.
     *
     * @return whatever the twelve-argument overload returns
     */
    public boolean trainWithFeatureInduction(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, int i3, int i4, double d, boolean z, double[] dArr) {
        final String defaultGainName = MCMaxEntTrainer.EXP_GAIN;
        return trainWithFeatureInduction(
                instanceList, instanceList2, instanceList3, transducerEvaluator,
                i, i2, i3, i4, d, z, dArr,
                defaultGainName);
    }

    /**
     * Alternates training with rounds of feature induction, then finishes with a plain
     * {@code train} call on the full training list.
     *
     * <p>Each round: optionally take a leading split of the training data, train on it (skipped
     * in the very first round), collect every sequence position whose true label receives a
     * value below {@code d} in the lattice's labeling, and build {@code FeatureInducer}s from
     * those positions, either one for all errors or one per label pair. The induced features
     * are then added to {@code instanceList} and, if present, {@code instanceList3}.
     *
     * @param instanceList training data; also the source of the shared feature selection and alphabets
     * @param instanceList2 may be null; only receives the shared feature selection
     *     (presumably validation data -- confirm against callers)
     * @param instanceList3 may be null; receives the shared feature selection and the induced
     *     features (presumably test data -- confirm against callers)
     * @param transducerEvaluator not referenced by this method
     * @param i total iteration budget; the final {@code train} call gets {@code i} minus
     *     {@code i2} times the number of rounds
     * @param i2 iteration cap for the {@code train} call inside each round
     * @param i3 number of feature induction rounds
     * @param i4 passed to {@code FeatureInducer} as its third constructor argument, and doubled
     *     for the fourth and fifth (exact meaning is defined by {@code FeatureInducer})
     * @param d threshold on the true label's value below which a position counts as an error
     * @param z true to build one inducer per (previous predicted label, predicted label) pair,
     *     false to build a single inducer over all collected errors
     * @param dArr may be null; entry {@code k} is the proportion of the data used in round {@code k}
     * @param str gain name compared against {@code MCMaxEntTrainer.EXP_GAIN},
     *     {@code GRADIENT_GAIN} and {@code INFORMATION_GAIN}
     * @return the result of the final {@code train(instanceList, ...)} call
     */
    public boolean trainWithFeatureInduction(InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, TransducerEvaluator transducerEvaluator, int i, int i2, int i3, int i4, double d, boolean z, double[] dArr, String str) {
        // i5 accumulates the iterations budgeted to the induction rounds.
        int i5 = 0;
        int size = this.crf.outputAlphabet.size();
        // Make the training list's feature selection the CRF-wide one, creating it if absent,
        // and share the same object with the other instance lists.
        this.crf.globalFeatureSelection = instanceList.getFeatureSelection();
        if (this.crf.globalFeatureSelection == null) {
            this.crf.globalFeatureSelection = new FeatureSelection(instanceList.getDataAlphabet());
            instanceList.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        if (instanceList2 != null) {
            instanceList2.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        if (instanceList3 != null) {
            instanceList3.setFeatureSelection(this.crf.globalFeatureSelection);
        }
        for (int i6 = 0; i6 < i3; i6++) {
            logger.info("Feature induction iteration " + i6);
            // instanceList4 is the data used this round: the full list unless dArr supplies a
            // proportion for this round, in which case it is the first part of a seeded split.
            InstanceList instanceList4 = instanceList;
            if (dArr != null && i6 < dArr.length) {
                logger.info("Training on " + dArr[i6] + "% of the data this round.");
                // NOTE(review): DEFAULT_GAUSSIAN_PRIOR_VARIANCE is used here as the value 1.0
                // (remaining proportion); presumably a decompiler constant substitution.
                instanceList4 = instanceList.split(new Random(serialVersionUID), new double[]{dArr[i6], DEFAULT_GAUSSIAN_PRIOR_VARIANCE - dArr[i6]})[0];
                instanceList4.setFeatureSelection(this.crf.globalFeatureSelection);
                logger.info("  which is " + instanceList4.size() + " instances");
            }
            // No training in the first round; the budget counter below advances regardless.
            if (i6 != 0) {
                train(instanceList4, i2);
            }
            i5 += i2;
            logger.info("Starting feature induction with " + this.crf.inputAlphabet.size() + " features.");
            // instanceList5 / arrayList: all error positions and their labelings.
            InstanceList instanceList5 = new InstanceList(instanceList.getDataAlphabet(), instanceList.getTargetAlphabet());
            instanceList5.setFeatureSelection(this.crf.globalFeatureSelection);
            ArrayList arrayList = new ArrayList();
            // instanceListArr / arrayListArr: the same errors bucketed by
            // [previous position's best label index][this position's best label index].
            InstanceList[][] instanceListArr = new InstanceList[size][size];
            ArrayList[][] arrayListArr = new ArrayList[size][size];
            for (int i7 = 0; i7 < size; i7++) {
                for (int i8 = 0; i8 < size; i8++) {
                    instanceListArr[i7][i8] = new InstanceList(instanceList.getDataAlphabet(), instanceList.getTargetAlphabet());
                    instanceListArr[i7][i8].setFeatureSelection(this.crf.globalFeatureSelection);
                    arrayListArr[i7][i8] = new ArrayList();
                }
            }
            // Scan every position of every instance in this round's data for errors.
            for (int i9 = 0; i9 < instanceList4.size(); i9++) {
                logger.info("instance=" + i9);
                Instance instance = instanceList4.get(i9);
                Sequence sequence = (Sequence) instance.getData();
                Sequence sequence2 = (Sequence) instance.getTarget();
                if (!$assertionsDisabled && sequence.size() != sequence2.size()) {
                    throw new AssertionError();
                }
                // Lattice built from the input sequence only (the output-sequence argument is null).
                SumLattice newSumLattice = this.crf.sumLatticeFactory.newSumLattice(this.crf, sequence, (Sequence) null, (Transducer.Incrementor) null, (LabelAlphabet) instanceList4.getTargetAlphabet());
                // i10 holds the previous position's best label index; 0 at the start of each instance.
                int i10 = 0;
                for (int i11 = 0; i11 < sequence2.size(); i11++) {
                    Label labelAtPosition = ((LabelSequence) sequence2).getLabelAtPosition(i11);
                    if (!$assertionsDisabled && labelAtPosition == null) {
                        throw new AssertionError();
                    }
                    LabelVector labelingAtPosition = newSumLattice.getLabelingAtPosition(i11);
                    // Value the labeling assigns to the true label (logged as "prtrue").
                    double value = labelingAtPosition.value(labelAtPosition.getIndex());
                    int bestIndex = labelingAtPosition.getBestIndex();
                    if (value < d) {
                        logger.info("Adding error: instance=" + i9 + " position=" + i11 + " prtrue=" + value + (labelAtPosition == labelingAtPosition.getBestLabel() ? "  " : " *") + " truelabel=" + labelAtPosition + " predlabel=" + labelingAtPosition.getBestLabel() + " fv=" + ((FeatureVector) sequence.get(i11)).toString(true));
                        instanceList5.add(sequence.get(i11), labelAtPosition, null, null);
                        arrayList.add(labelingAtPosition);
                        instanceListArr[i10][bestIndex].add(sequence.get(i11), labelAtPosition, null, null);
                        arrayListArr[i10][bestIndex].add(labelingAtPosition);
                    }
                    i10 = bestIndex;
                }
            }
            logger.info("Error instance list size = " + instanceList5.size());
            if (z) {
                // Per-label-pair induction: one FeatureInducer per bucket with at least 20 errors.
                FeatureInducer[][] featureInducerArr = new FeatureInducer[size][size];
                for (int i12 = 0; i12 < size; i12++) {
                    for (int i13 = 0; i13 < size; i13++) {
                        logger.info("Doing feature induction for " + this.crf.outputAlphabet.lookupObject(i12) + " -> " + this.crf.outputAlphabet.lookupObject(i13) + " with " + instanceListArr[i12][i13].size() + " instances");
                        if (instanceListArr[i12][i13].size() < 20) {
                            logger.info("..skipping because only " + instanceListArr[i12][i13].size() + " instances.");
                        } else {
                            // Copy the bucket's labelings into a typed array for the factory.
                            int size2 = arrayListArr[i12][i13].size();
                            LabelVector[] labelVectorArr = new LabelVector[size2];
                            for (int i14 = 0; i14 < size2; i14++) {
                                labelVectorArr[i14] = (LabelVector) arrayListArr[i12][i13].get(i14);
                            }
                            // Pick the ranking factory by gain name; an unrecognised name
                            // leaves factory null.
                            RankedFeatureVector.Factory factory = null;
                            if (str.equals(MCMaxEntTrainer.EXP_GAIN)) {
                                factory = new ExpGain.Factory(labelVectorArr, this.gaussianPriorVariance);
                            } else if (str.equals(MCMaxEntTrainer.GRADIENT_GAIN)) {
                                factory = new GradientGain.Factory(labelVectorArr);
                            } else if (str.equals(MCMaxEntTrainer.INFORMATION_GAIN)) {
                                factory = new InfoGain.Factory();
                            }
                            featureInducerArr[i12][i13] = new FeatureInducer(factory, instanceListArr[i12][i13], i4, 2 * i4, 2 * i4);
                            this.crf.featureInducers.add(featureInducerArr[i12][i13]);
                        }
                    }
                }
                // Second pass: apply every inducer that was created to the training list and,
                // if given, to instanceList3.
                for (int i15 = 0; i15 < size; i15++) {
                    for (int i16 = 0; i16 < size; i16++) {
                        logger.info("Adding new induced features for " + this.crf.outputAlphabet.lookupObject(i15) + " -> " + this.crf.outputAlphabet.lookupObject(i16));
                        if (featureInducerArr[i15][i16] == null) {
                            logger.info("...skipping because no features induced.");
                        } else {
                            featureInducerArr[i15][i16].induceFeaturesFor(instanceList, false, false);
                            if (instanceList3 != null) {
                                featureInducerArr[i15][i16].induceFeaturesFor(instanceList3, false, false);
                            }
                        }
                    }
                }
            } else {
                // Single inducer built from all collected errors.
                int size3 = arrayList.size();
                LabelVector[] labelVectorArr2 = new LabelVector[size3];
                for (int i17 = 0; i17 < size3; i17++) {
                    labelVectorArr2[i17] = (LabelVector) arrayList.get(i17);
                }
                // Same gain-name dispatch as above; an unrecognised name leaves factory2 null.
                RankedFeatureVector.Factory factory2 = null;
                if (str.equals(MCMaxEntTrainer.EXP_GAIN)) {
                    factory2 = new ExpGain.Factory(labelVectorArr2, this.gaussianPriorVariance);
                } else if (str.equals(MCMaxEntTrainer.GRADIENT_GAIN)) {
                    factory2 = new GradientGain.Factory(labelVectorArr2);
                } else if (str.equals(MCMaxEntTrainer.INFORMATION_GAIN)) {
                    factory2 = new InfoGain.Factory();
                }
                FeatureInducer featureInducer = new FeatureInducer(factory2, instanceList5, i4, 2 * i4, 2 * i4);
                this.crf.featureInducers.add(featureInducer);
                featureInducer.induceFeaturesFor(instanceList, false, false);
                if (instanceList3 != null) {
                    featureInducer.induceFeaturesFor(instanceList3, false, false);
                }
                logger.info("CRF4 FeatureSelection now includes " + this.crf.globalFeatureSelection.cardinality() + " features");
            }
        }
        // Spend whatever is left of the iteration budget on the full training list; train()
        // returns false immediately if nothing is left.
        return train(instanceList, i - i5);
    }

    /**
     * Sets the hyperbolic-prior flag that is handed to the optimizable the next time it is built.
     *
     * @param z the new flag value
     */
    public void setUseHyperbolicPrior(boolean z) {
        usingHyperbolicPrior = z;
    }

    /**
     * Sets the hyperbolic prior slope that is handed to the optimizable the next time it is built.
     *
     * @param d the new slope
     */
    public void setHyperbolicPriorSlope(double d) {
        hyperbolicPriorSlope = d;
    }

    /**
     * Sets the hyperbolic prior sharpness that is handed to the optimizable the next time it is
     * built.
     *
     * @param d the new sharpness
     */
    public void setHyperbolicPriorSharpness(double d) {
        hyperbolicPriorSharpness = d;
    }

    /**
     * Returns the hyperbolic prior slope. Despite the "Use" in the name, this returns the slope
     * value itself, not a flag.
     *
     * @return the configured slope
     */
    public double getUseHyperbolicPriorSlope() {
        return hyperbolicPriorSlope;
    }

    /**
     * Returns the hyperbolic prior sharpness. Despite the "Use" in the name, this returns the
     * sharpness value itself, not a flag.
     *
     * @return the configured sharpness
     */
    public double getUseHyperbolicPriorSharpness() {
        return hyperbolicPriorSharpness;
    }

    /**
     * Sets the Gaussian prior variance that is handed to the optimizable the next time it is
     * built and to {@code ExpGain.Factory} during feature induction.
     *
     * @param d the new variance
     */
    public void setGaussianPriorVariance(double d) {
        gaussianPriorVariance = d;
    }

    /**
     * Returns the configured Gaussian prior variance.
     *
     * @return the current variance
     */
    public double getGaussianPriorVariance() {
        return gaussianPriorVariance;
    }

    /**
     * Chooses how {@code getOptimizableCRF} sets weight dimensions: {@code true} uses
     * {@code setWeightsDimensionAsIn}, {@code false} uses {@code setWeightsDimensionDensely}.
     *
     * @param z the new flag value
     */
    public void setUseSparseWeights(boolean z) {
        useSparseWeights = z;
    }

    /**
     * Returns the sparse-weights flag.
     *
     * @return the current flag value
     */
    public boolean getUseSparseWeights() {
        return useSparseWeights;
    }

    /**
     * Sets the flag passed as the second argument to {@code crf.setWeightsDimensionAsIn} when
     * sparse weights are in use.
     *
     * @param z the new flag value
     */
    public void setUseSomeUnsupportedTrick(boolean z) {
        useSomeUnsupportedTrick = z;
    }

    /**
     * Serialization hook. Writes a version number followed by the prior settings, the three
     * cached stamps and two flags, then always fails because the format is unfinished.
     *
     * @param objectOutputStream destination stream
     * @throws IOException if a write fails
     * @throws IllegalStateException always, after the fields have been written
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        // Version header.
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        // Prior configuration.
        objectOutputStream.writeBoolean(usingHyperbolicPrior);
        objectOutputStream.writeDouble(gaussianPriorVariance);
        objectOutputStream.writeDouble(hyperbolicPriorSlope);
        objectOutputStream.writeDouble(hyperbolicPriorSharpness);
        // Cache stamps, in the order gradient, value, structure.
        objectOutputStream.writeInt(cachedGradientWeightsStamp);
        objectOutputStream.writeInt(cachedValueWeightsStamp);
        objectOutputStream.writeInt(cachedWeightsStructureStamp);
        // Remaining flags.
        objectOutputStream.writeBoolean(printGradient);
        objectOutputStream.writeBoolean(useSparseWeights);
        throw new IllegalStateException("Implementation not yet complete.");
    }

    /**
     * Deserialization hook. Reads the fields in exactly the order {@code writeObject} emits
     * them, then always fails because the format is unfinished.
     *
     * <p>The three cached stamps are now consumed between the prior settings and the flags.
     * Previously they were skipped, so the two boolean reads that follow took their bytes from
     * the first stamp instead of from the flags.
     *
     * @param objectInputStream source stream
     * @throws IOException if a read fails
     * @throws ClassNotFoundException declared for the serialization contract; not thrown here
     * @throws IllegalStateException always, after the fields have been read
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Version header; read to advance the stream, value currently unused.
        objectInputStream.readInt();
        this.usingHyperbolicPrior = objectInputStream.readBoolean();
        this.gaussianPriorVariance = objectInputStream.readDouble();
        this.hyperbolicPriorSlope = objectInputStream.readDouble();
        this.hyperbolicPriorSharpness = objectInputStream.readDouble();
        // Same order as writeObject: gradient, value, structure.
        this.cachedGradientWeightsStamp = objectInputStream.readInt();
        this.cachedValueWeightsStamp = objectInputStream.readInt();
        this.cachedWeightsStructureStamp = objectInputStream.readInt();
        this.printGradient = objectInputStream.readBoolean();
        this.useSparseWeights = objectInputStream.readBoolean();
        throw new IllegalStateException("Implementation not yet complete.");
    }

    static {
        $assertionsDisabled = !CRFTrainerByLabelLikelihood.class.desiredAssertionStatus();
        logger = MalletLogger.getLogger(CRFTrainerByLabelLikelihood.class.getName());
    }
}
