package cc.mallet.grmm.learning;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.CachingOptimizable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Timing;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import org.springframework.beans.PropertyAccessor;

/* loaded from: input_file:cc/mallet/grmm/learning/PwplACRFTrainer.class */
public class PwplACRFTrainer extends DefaultAcrfTrainer {
    /** Logger shared by the trainer and its nested optimizable. */
    private static final Logger logger = MalletLogger.getLogger(PwplACRFTrainer.class.getName());

    /** When true, each gradient component is echoed to stdout as it is computed. */
    public static boolean printGradient = false;

    /** Wrong-wrong mode: plain piecewise training, no wrong-wrong terms. */
    public static final int NO_WRONG_WRONG = 0;

    /** Wrong-wrong mode: add conditional wrong-wrong terms after an initial training phase. */
    public static final int CONDITION_WW = 1;

    // Active wrong-wrong mode; one of the constants above.
    private int wrongWrongType = NO_WRONG_WRONG;

    // Iteration count handed to the initial super.train call before wrong-wrongs are added.
    private int wrongWrongIter = 10;

    // Clique-marginal value that an assignment must exceed to be considered for wrong-wrongs.
    private double wrongWrongThreshold = 0.1d;

    // Directory in which the intermediate model "initial-acrf.ser.gz" is written.
    private File outputPrefix = new File(".");

    /* loaded from: input_file:cc/mallet/grmm/learning/PwplACRFTrainer$Maxable.class */
    public class Maxable extends CachingOptimizable.ByGradient {
        // Model being trained; its templates own the weight vectors that are optimized here.
        private ACRF acrf;
        // Training instances over which the objective is summed.
        InstanceList trainData;
        // Templates of the model, cached from acrf.getTemplates().
        private ACRF.Template[] templates;
        // Total parameter count (default weights plus per-assignment weights), summed in initWeights.
        private int numParameters;
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0d;
        // Empirical feature counts, indexed [template][assignment].
        SparseVector[][] constraints;
        // Model-expected feature counts, same layout as constraints.
        SparseVector[][] expectations;
        // Empirical counts for each template's default weights.
        SparseVector[] defaultConstraints;
        // Expected counts for each template's default weights.
        SparseVector[] defaultExpectations;
        // Per-instance wrong-wrong terms; null until addWrongWrong has run. Declared with an element
        // type so that enhanced-for loops over an entry can bind WrongWrong directly (a raw List[]
        // only yields Object, which does not compile in such loops).
        private List<WrongWrong>[] allWrongWrongs;
        // Instances found to have infinite value on the first evaluation; skipped afterwards.
        protected BitSet infiniteValues = null;
        private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        // Profiling counters for computeValueGradientForAssn, reset on every computeValue call.
        private int numCvgaCalls = 0;
        private long timePerCvgaCall = 0;

        /**
         * One wrong-wrong term: a variable, the unrolled clique containing it, and a clique
         * assignment encoded as a single index. Positions are kept as indices rather than
         * references and are resolved against whichever unrolled graph is passed to the finders.
         */
        public class WrongWrong {
            /** Index of the variable within the unrolled graph. */
            int varIdx;
            /** Index of the unrolled clique within the unrolled graph. */
            int vsIdx;
            /** Single-index encoding of the clique assignment. */
            int assnIdx;

            public WrongWrong(ACRF.UnrolledGraph unrolledGraph, VarSet varSet, Variable variable, int i) {
                int variableIndex = unrolledGraph.getIndex(variable);
                int cliqueIndex = unrolledGraph.getIndex(varSet);
                varIdx = variableIndex;
                vsIdx = cliqueIndex;
                assnIdx = i;
            }

            /** Looks up this term's clique in {@code unrolledGraph}. */
            public ACRF.UnrolledVarSet findVarSet(ACRF.UnrolledGraph unrolledGraph) {
                final int index = vsIdx;
                return unrolledGraph.getUnrolledVarSet(index);
            }

            /** Looks up this term's variable in {@code unrolledGraph}. */
            public Variable findVariable(ACRF.UnrolledGraph unrolledGraph) {
                final int index = varIdx;
                return unrolledGraph.get(index);
            }
        }

        /** Returns the variance of the Gaussian prior placed on the weights. */
        public double getGaussianPriorVariance() {
            final double variance = gaussianPriorVariance;
            return variance;
        }

        /**
         * Sets the variance of the Gaussian prior on the weights.
         * <p>
         * The variance enters both computeValue and computeValueGradient. The cached value and
         * gradient are therefore marked stale; otherwise the optimizer could keep receiving
         * numbers computed under the previous variance.
         *
         * @param d new prior variance
         */
        public void setGaussianPriorVariance(double d) {
            this.gaussianPriorVariance = d;
            forceStale();
        }

        /**
         * Lets each template allocate its weights from the training data, and adds the number of
         * parameters each template reports to {@code numParameters}.
         */
        private void initWeights(InstanceList instanceList) {
            for (ACRF.Template template : this.templates) {
                this.numParameters += template.initWeights(instanceList);
            }
        }

        /**
         * Allocates zeroed constraint and expectation vectors that mirror the weight layout:
         * one default vector per template, plus one vector per weight vector of each template.
         * Each is a {@code cloneMatrixZeroed()} copy of the corresponding weight vector.
         */
        private void initConstraintsExpectations() {
            int numTemplates = this.templates.length;
            this.defaultConstraints = new SparseVector[numTemplates];
            this.defaultExpectations = new SparseVector[numTemplates];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector defaultWeights = this.templates[t].getDefaultWeights();
                this.defaultConstraints[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
                this.defaultExpectations[t] = (SparseVector) defaultWeights.cloneMatrixZeroed();
            }
            // The fields are SparseVector[][]; allocate the outer dimension only and size each
            // row below to the template's number of weight vectors.
            this.constraints = new SparseVector[numTemplates][];
            this.expectations = new SparseVector[numTemplates][];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                this.constraints[t] = new SparseVector[weights.length];
                this.expectations[t] = new SparseVector[weights.length];
                for (int a = 0; a < weights.length; a++) {
                    this.constraints[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                    this.expectations[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                }
            }
        }

        /** Zeroes the call counter and accumulated time kept for computeValueGradientForAssn. */
        void resetProfilingForCall() {
            this.timePerCvgaCall = 0L;
            this.numCvgaCalls = 0;
        }

        /** Sets every expectation vector, including each template's default one, to zero. */
        void resetExpectations() {
            int t = 0;
            while (t < this.expectations.length) {
                this.defaultExpectations[t].setAll(0.0d);
                for (SparseVector expectation : this.expectations[t]) {
                    expectation.setAll(0.0d);
                }
                t++;
            }
        }

        /** Sets every constraint vector, including each template's default one, to zero. */
        void resetConstraints() {
            for (int t = 0; t < constraints.length; t++) {
                defaultConstraints[t].setAll(0.0d);
                SparseVector[] perAssignment = constraints[t];
                for (SparseVector vec : perAssignment) {
                    vec.setAll(0.0d);
                }
            }
        }

        /**
         * Builds the training objective for {@code acrf} over {@code instanceList}. It initializes
         * template weights, allocates constraint/expectation storage and collects the empirical
         * constraints from the training data.
         */
        protected Maxable(ACRF acrf, InstanceList instanceList) {
            PwplACRFTrainer.logger.finest("Initializing OptimizableACRF.");
            this.acrf = acrf;
            this.templates = acrf.getTemplates();
            this.trainData = instanceList;

            initWeights(trainData);
            initConstraintsExpectations();

            final int numInstances = trainData.size();
            // Nothing has been computed yet for the freshly initialized weights.
            cachedGradientStale = true;
            cachedValueStale = true;

            PwplACRFTrainer.logger.info("Number of training instances = " + numInstances);
            PwplACRFTrainer.logger.info("Number of parameters = " + numParameters);
            describePrior();

            PwplACRFTrainer.logger.fine("Computing constraints");
            collectConstraints(trainData);
        }

        /** Logs the prior in use and its variance. */
        private void describePrior() {
            String message = "Using gaussian prior with variance " + gaussianPriorVariance;
            PwplACRFTrainer.logger.info(message);
        }

        /** Returns the total number of parameters, i.e. default weights plus per-assignment weights. */
        @Override // cc.mallet.optimize.Optimizable
        public int getNumParameters() {
            final int count = numParameters;
            return count;
        }

        /**
         * Copies all parameters into {@code dArr}: first every template's default weights, then
         * every weight vector of every template, in template order.
         *
         * @throws IllegalArgumentException if {@code dArr.length} differs from the parameter count
         */
        @Override // cc.mallet.optimize.Optimizable
        public void getParameters(double[] dArr) {
            if (dArr.length != numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int offset = 0;
            for (ACRF.Template template : templates) {
                double[] src = template.getDefaultWeights().getValues();
                System.arraycopy(src, 0, dArr, offset, src.length);
                offset += src.length;
            }
            for (ACRF.Template template : templates) {
                SparseVector[] weightVectors = template.getWeights();
                for (int k = 0; k < weightVectors.length; k++) {
                    double[] src = weightVectors[k].getValues();
                    System.arraycopy(src, 0, dArr, offset, src.length);
                    offset += src.length;
                }
            }
        }

        /**
         * Copies {@code dArr} into the arrays returned by the templates' {@code getValues()}. The
         * layout is the same as getParameters (default weights first, then weight vectors). The
         * cached value and gradient are marked stale.
         */
        @Override // cc.mallet.grmm.util.CachingOptimizable.Base
        protected void setParametersInternal(double[] dArr) {
            cachedGradientStale = true;
            cachedValueStale = true;
            int pos = 0;
            for (ACRF.Template template : templates) {
                double[] dst = template.getDefaultWeights().getValues();
                System.arraycopy(dArr, pos, dst, 0, dst.length);
                pos += dst.length;
            }
            for (ACRF.Template template : templates) {
                SparseVector[] weightVectors = template.getWeights();
                for (SparseVector vec : weightVectors) {
                    double[] dst = vec.getValues();
                    System.arraycopy(dArr, pos, dst, 0, dst.length);
                    pos += dst.length;
                }
            }
        }

        /** Returns template {@code i}'s expectation vectors (the live array, not a copy). */
        public SparseVector[] getExpectations(int i) {
            SparseVector[] perAssignment = expectations[i];
            return perAssignment;
        }

        /** Returns template {@code i}'s constraint vectors (the live array, not a copy). */
        public SparseVector[] getConstraints(int i) {
            SparseVector[] perAssignment = constraints[i];
            return perAssignment;
        }

        /** Prints all parameters to stdout on one line, each followed by a tab. */
        public void printParameters() {
            double[] params = new double[this.numParameters];
            getParameters(params);
            StringBuilder line = new StringBuilder();
            for (int k = 0; k < params.length; k++) {
                line.append(String.valueOf(params[k])).append('\t');
            }
            System.out.println(line.toString());
        }

        /**
         * Computes the regularized objective for the current weights.
         * <p>
         * As a side effect, resets and re-accumulates the expectations read by
         * computeValueGradient. Instances with infinite value on the first evaluation are
         * recorded in {@code infiniteValues} and skipped from then on. Any other infinite
         * instance value, or a NaN, makes the result {@code Double.NEGATIVE_INFINITY}.
         *
         * @return sum of per-instance values plus the log Gaussian prior (up to a constant)
         */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient
        protected double computeValue() {
            double value = 0.0d;
            int numInstances = this.trainData.size();
            long startMillis = System.currentTimeMillis();
            long unrollMillis = 0;
            resetProfilingForCall();
            // The first evaluation records which instances are infinite; later ones consult that record.
            boolean firstEvaluation = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                firstEvaluation = true;
            }
            resetExpectations();
            for (int i = 0; i < numInstances; i++) {
                Instance instance = this.trainData.get(i);
                long unrollStart = System.currentTimeMillis();
                ACRF.UnrolledGraph graph = this.acrf.unrollStructureOnly(instance);
                unrollMillis += System.currentTimeMillis() - unrollStart;
                double instanceValue = collectExpectationsAndValue(graph, graph.getAssignment(), i);
                if (!Double.isInfinite(instanceValue)) {
                    if (Double.isNaN(instanceValue)) {
                        System.out.println("NaN on instance " + i + " : " + instance.getName());
                        printDebugInfo(graph);
                        PwplACRFTrainer.logger.warning("Value is NaN in ACRF.getValue() Instance " + i + " : returning -infinity... ");
                        return Double.NEGATIVE_INFINITY;
                    }
                    value += instanceValue;
                } else if (firstEvaluation) {
                    PwplACRFTrainer.logger.warning("Instance " + instance.getName() + " has infinite value; skipping.");
                    this.infiniteValues.set(i);
                } else if (!this.infiniteValues.get(i)) {
                    PwplACRFTrainer.logger.warning("Infinite value on instance " + instance.getName() + " returning -infinity");
                    return Double.NEGATIVE_INFINITY;
                }
            }
            value += gaussianLogPrior();
            PwplACRFTrainer.logger.info("ACRF Inference time (ms) = " + (System.currentTimeMillis() - startMillis));
            PwplACRFTrainer.logger.info("ACRF unroll time (ms) = " + unrollMillis);
            PwplACRFTrainer.logger.info("getValue (loglikelihood) = " + value);
            PwplACRFTrainer.logger.info("Number cVGA calls = " + this.numCvgaCalls);
            PwplACRFTrainer.logger.info("Total cVGA time (ms) = " + this.timePerCvgaCall);
            return value;
        }

        /**
         * Returns {@code sum(-w * w / (2 * variance))} over every finite weight, covering both
         * the default weights and the per-assignment weight vectors. The default weights are
         * included because computeValueGradient applies the matching {@code -w / variance} term
         * to them. Leaving them out makes the value and the gradient inconsistent.
         */
        private double gaussianLogPrior() {
            double denom = 2.0d * this.gaussianPriorVariance;
            double prior = 0.0d;
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector defaults = this.templates[t].getDefaultWeights();
                for (int loc = 0; loc < defaults.numLocations(); loc++) {
                    double w = defaults.valueAtLocation(loc);
                    if (weightValid(w, t, loc)) {
                        prior += ((-w) * w) / denom;
                    }
                }
                SparseVector[] weights = this.templates[t].getWeights();
                for (int a = 0; a < weights.length; a++) {
                    for (int loc = 0; loc < weights[a].numLocations(); loc++) {
                        double w = weights[a].valueAtLocation(loc);
                        if (weightValid(w, t, a)) {
                            prior += ((-w) * w) / denom;
                        }
                    }
                }
            }
            return prior;
        }

        /**
         * Fills {@code dArr} with the objective's gradient. Each component is the empirical
         * constraint minus the model expectation minus the prior term {@code w / variance}.
         * The layout matches getParameters: all default weights first, then all weight vectors.
         * <p>
         * It reads the expectations accumulated by computeValue. The caller is assumed to have
         * evaluated the value for the current parameters first.
         */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient
        protected void computeValueGradient(double[] dArr) {
            int next = fillDefaultWeightGradient(dArr, 0);
            fillAssignmentWeightGradient(dArr, next);
        }

        /** Writes the default-weight gradient entries starting at {@code start}; returns the next free index. */
        private int fillDefaultWeightGradient(double[] dArr, int start) {
            int idx = start;
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector defaultWeights = this.templates[t].getDefaultWeights();
                SparseVector constraint = this.defaultConstraints[t];
                SparseVector expectation = this.defaultExpectations[t];
                for (int loc = 0; loc < defaultWeights.numLocations(); loc++) {
                    double w = defaultWeights.valueAtLocation(loc);
                    double ctr = constraint.valueAtLocation(loc);
                    double exp = expectation.valueAtLocation(loc);
                    if (PwplACRFTrainer.printGradient) {
                        System.out.println(" gradient [" + idx + "] = " + ctr + " (ctr) - " + exp + " (exp) - " + (w / this.gaussianPriorVariance) + " (reg)  [feature=DEFAULT]");
                    }
                    dArr[idx++] = (ctr - exp) - (w / this.gaussianPriorVariance);
                }
            }
            return idx;
        }

        /** Writes the gradient entries for every template's weight vectors starting at {@code start}. */
        private void fillAssignmentWeightGradient(double[] dArr, int start) {
            int idx = start;
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                for (int a = 0; a < weights.length; a++) {
                    SparseVector weightVec = weights[a];
                    SparseVector ctrVec = this.constraints[t][a];
                    SparseVector expVec = this.expectations[t][a];
                    for (int loc = 0; loc < weightVec.numLocations(); loc++) {
                        double w = weightVec.valueAtLocation(loc);
                        double ctr = ctrVec.valueAtLocation(loc);
                        double exp = expVec.valueAtLocation(loc);
                        double grad;
                        if (Double.isInfinite(w)) {
                            // loc is a position in the sparse storage; the feature id is indexAtLocation(loc).
                            PwplACRFTrainer.logger.warning("Infinite weight for node index " + a + " feature " + this.acrf.getInputAlphabet().lookupObject(weightVec.indexAtLocation(loc)));
                            grad = 0.0d;
                        } else {
                            grad = (ctr - (w / this.gaussianPriorVariance)) - exp;
                        }
                        if (PwplACRFTrainer.printGradient) {
                            System.out.println(" gradient [" + idx + "] = " + ctr + " (ctr) - " + exp + " (exp) - " + (w / this.gaussianPriorVariance) + " (reg)  [feature=" + this.acrf.getInputAlphabet().lookupObject(weightVec.indexAtLocation(loc)) + "]");
                        }
                        dArr[idx++] = grad;
                    }
                }
            }
        }

        /**
         * Accumulates expectations for one instance and returns its value. For every variable of
         * every clique whose template index is not -1, it adds the log conditional probability of
         * the variable's value in {@code assignment} given the rest of the clique. When the
         * trainer is in {@code CONDITION_WW} mode, the instance's wrong-wrong terms are added too.
         *
         * @throws IllegalStateException if the trainer's wrong-wrong mode is unknown
         */
        private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolledGraph, Assignment assignment, int i) {
            double total = 0.0d;
            for (Iterator it = unrolledGraph.unrolledVarSetIterator(); it.hasNext(); ) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                // Cliques with template index -1 are skipped (presumably untrained templates — confirm).
                if (clique.getTemplate().index == -1) {
                    continue;
                }
                for (int v = 0; v < clique.size(); v++) {
                    total += computeValueGradientForAssn(assignment, clique, clique.get(v));
                }
            }
            switch (PwplACRFTrainer.this.wrongWrongType) {
                case NO_WRONG_WRONG:
                    break;
                case CONDITION_WW:
                    total += addConditionalWW(unrolledGraph, i);
                    break;
                default:
                    throw new IllegalStateException();
            }
            return total;
        }

        /**
         * Adds the wrong-wrong terms recorded for instance {@code i}. Each term is the log
         * conditional probability of one variable under a stored clique assignment, and it also
         * contributes to the expectations. Returns 0 when no wrong-wrongs have been added yet.
         */
        private double addConditionalWW(ACRF.UnrolledGraph unrolledGraph, int i) {
            double d = 0.0d;
            if (this.allWrongWrongs != null) {
                // The field is an array of List. Iterating as Object and casting compiles whether
                // or not that List is declared with an element type.
                for (Object element : this.allWrongWrongs[i]) {
                    WrongWrong wrongWrong = (WrongWrong) element;
                    Variable findVariable = wrongWrong.findVariable(unrolledGraph);
                    ACRF.UnrolledVarSet findVarSet = wrongWrong.findVarSet(unrolledGraph);
                    d += computeValueGradientForAssn(Assignment.makeFromSingleIndex(findVarSet, wrongWrong.assnIdx), findVarSet, findVariable);
                }
            }
            return d;
        }

        /**
         * Returns log P(variable = its value in {@code assignment} | the other clique variables)
         * under the current weights. The conditional distribution over the variable's outcomes
         * is added, as weighted feature vectors, into the template's expectations. Also updates
         * the call counter and the accumulated time.
         */
        private double computeValueGradientForAssn(Assignment assignment, ACRF.UnrolledVarSet unrolledVarSet, Variable variable) {
            numCvgaCalls++;
            Timing timer = new Timing();
            ACRF.Template template = unrolledVarSet.getTemplate();
            int tIdx = template.index;
            Assignment local = Assignment.restriction(assignment, unrolledVarSet);
            int k = variable.getNumOutcomes();
            double[] logScores = new double[k];
            int[] cliqueAssn = new int[k];
            for (int outcome = 0; outcome < k; outcome++) {
                // Only the target variable changes; the rest of the clique keeps its values.
                local.setValue(variable, outcome);
                logScores[outcome] = computeLogFactorValue(local, template, unrolledVarSet.getFv());
                cliqueAssn[outcome] = local.singleIndex();
            }
            double logZ = Maths.sumLogProb(logScores);
            for (int outcome = 0; outcome < k; outcome++) {
                double prob = Math.exp(logScores[outcome] - logZ);
                int a = cliqueAssn[outcome];
                expectations[tIdx][a].plusEqualsSparse(unrolledVarSet.getFv(), prob);
                SparseVector defaults = defaultExpectations[tIdx];
                if (defaults.location(a) != -1) {
                    defaults.incrementValue(a, prob);
                }
            }
            int observed = assignment.get(variable);
            timePerCvgaCall += timer.elapsedTime();
            return logScores[observed] - logZ;
        }

        /**
         * Unnormalized log potential of the clique assignment: the dot product of the
         * assignment's weight vector with the features, plus the template's default weight.
         */
        private double computeLogFactorValue(Assignment assignment, ACRF.Template template, FeatureVector featureVector) {
            SparseVector[] rows = template.getWeights();
            int assn = assignment.singleIndex();
            double featureScore = rows[assn].dotProduct((SparseVector) featureVector);
            return featureScore + template.getDefaultWeight(assn);
        }

        /**
         * Adds the empirical feature counts of {@code instanceList} into the constraints.
         * <ul>
         * <li>Each clique with a template index other than -1 contributes its observed
         * assignment's feature vector once per clique variable, matching the one conditional
         * term per variable in the objective.</li>
         * <li>Each recorded wrong-wrong contributes its clique's feature vector once.</li>
         * </ul>
         */
        public void collectConstraints(InstanceList instanceList) {
            for (int i = 0; i < instanceList.size(); i++) {
                PwplACRFTrainer.logger.finest("*** Collecting constraints for instance " + i);
                ACRF.UnrolledGraph unrolledGraph = new ACRF.UnrolledGraph(instanceList.get(i), this.templates, null, false);
                Iterator unrolledVarSetIterator = unrolledGraph.unrolledVarSetIterator();
                while (unrolledVarSetIterator.hasNext()) {
                    ACRF.UnrolledVarSet unrolledVarSet = (ACRF.UnrolledVarSet) unrolledVarSetIterator.next();
                    int templateIdx = unrolledVarSet.getTemplate().index;
                    if (templateIdx != -1) {
                        int assn = unrolledVarSet.lookupAssignmentNumber();
                        this.constraints[templateIdx][assn].plusEqualsSparse(unrolledVarSet.getFv(), unrolledVarSet.size());
                        if (this.defaultConstraints[templateIdx].location(assn) != -1) {
                            this.defaultConstraints[templateIdx].incrementValue(assn, unrolledVarSet.size());
                        }
                    }
                }
                if (this.allWrongWrongs != null) {
                    // The field is an array of List. Iterating as Object and casting compiles
                    // whether or not that List is declared with an element type.
                    for (Object element : this.allWrongWrongs[i]) {
                        WrongWrong wrongWrong = (WrongWrong) element;
                        ACRF.UnrolledVarSet findVarSet = wrongWrong.findVarSet(unrolledGraph);
                        int wwTemplateIdx = findVarSet.getTemplate().index;
                        int wwAssn = wrongWrong.assnIdx;
                        this.constraints[wwTemplateIdx][wwAssn].plusEqualsSparse(findVarSet.getFv(), 1.0d);
                        if (this.defaultConstraints[wwTemplateIdx].location(wwAssn) != -1) {
                            this.defaultConstraints[wwTemplateIdx].incrementValue(wwAssn, 1.0d);
                        }
                    }
                }
            }
        }

        /**
         * Computes the current gradient and writes it to file {@code str}, one component per line.
         * <p>
         * If the file cannot be opened, the error is reported on System.err rather than thrown.
         * PrintStream never throws on write, so write errors are detected with
         * {@code checkError()} and reported the same way.
         */
        void dumpGradientToFile(String str) {
            PrintStream printStream = null;
            try {
                double[] dArr = new double[getNumParameters()];
                getValueGradient(dArr);
                printStream = new PrintStream(new FileOutputStream(str));
                for (int i = 0; i < this.numParameters; i++) {
                    printStream.println(dArr[i]);
                }
                if (printStream.checkError()) {
                    System.err.println("Error while writing gradient to " + str);
                }
            } catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            } finally {
                // Release the file handle even if the gradient computation or a write throws.
                if (printStream != null) {
                    printStream.close();
                }
            }
        }

        /** Prints each template's default constraints, then its default expectations, with headers on System.out. */
        void dumpDefaults() {
            dumpTemplateVectors("Default constraints", defaultConstraints);
            dumpTemplateVectors("Default expectations", defaultExpectations);
        }

        /** Prints {@code title}, then a "Template k" header followed by {@code print()} for each vector. */
        private void dumpTemplateVectors(String title, SparseVector[] vectors) {
            System.out.println(title);
            for (int k = 0; k < vectors.length; k++) {
                System.out.println("Template " + k);
                vectors[k].print();
            }
        }

        /**
         * Dumps the model to System.err. Then, for each clique of the graph, prints to System.out
         * the clique's variable values, its factor's value at the graph's assignment, and the factor.
         */
        void printDebugInfo(ACRF.UnrolledGraph unrolledGraph) {
            this.acrf.print(System.err);
            Assignment observed = unrolledGraph.getAssignment();
            for (Iterator cliques = unrolledGraph.unrolledVarSetIterator(); cliques.hasNext(); ) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliques.next();
                System.out.println("Clique " + clique);
                dumpAssnForClique(observed, clique);
                Factor factor = unrolledGraph.factorOf((VarSet) clique);
                System.out.println("Value = " + factor.value(observed));
                System.out.println(factor);
            }
        }

        /** Prints each clique variable with its value object and integer value under {@code assignment}. */
        void dumpAssnForClique(Assignment assignment, ACRF.UnrolledVarSet unrolledVarSet) {
            for (Iterator vars = unrolledVarSet.iterator(); vars.hasNext(); ) {
                Variable var = (Variable) vars.next();
                System.out.println(var + " ==> " + assignment.getObject(var) + "  (" + assignment.get(var) + ")");
            }
        }

        /**
         * Returns true if {@code d} is a finite number. Otherwise logs a warning naming the
         * template index {@code i} and the index {@code i2} passed by the caller, and returns false.
         */
        private boolean weightValid(double d, int i, int i2) {
            if (Double.isInfinite(d)) {
                PwplACRFTrainer.logger.warning("Weight is infinite for clique " + i + " assignment " + i2);
                return false;
            }
            if (Double.isNaN(d)) {
                PwplACRFTrainer.logger.warning("Weight is NaN for clique " + i + " assignment " + i2);
                return false;
            }
            return true;
        }

        /**
         * Runs inference on every instance with the current model and records wrong-wrong terms.
         * A term is a (clique, clique assignment, variable) triple such that:
         * <ul>
         * <li>the assignment's clique marginal exceeds the trainer's threshold;</li>
         * <li>some other clique variable differs from the observed assignment;</li>
         * <li>the variable itself matches the observed value.</li>
         * </ul>
         * Instances whose factor graph has no factors get an empty list and a warning on
         * System.err. Afterwards the constraints are rebuilt to include the new terms and the
         * cache is forced stale.
         */
        public void addWrongWrong(InstanceList instanceList) {
            this.allWrongWrongs = new List[instanceList.size()];
            int totalAdded = 0;
            for (int n = 0; n < instanceList.size(); n++) {
                this.allWrongWrongs[n] = new ArrayList();
                Instance instance = instanceList.get(n);
                ACRF.UnrolledGraph graph = this.acrf.unroll(instance);
                if (graph.factors().size() == 0) {
                    System.err.println("WARNING: FactorGraph for instance " + instance.getName() + " : no factors.");
                    continue;
                }
                int added = collectWrongWrongsForInstance(graph, this.allWrongWrongs[n]);
                PwplACRFTrainer.logger.info("WrongWrongs: Instance " + n + " : " + instance.getName() + " Num added = " + added);
                totalAdded += added;
            }
            resetConstraints();
            collectConstraints(instanceList);
            forceStale();
            PwplACRFTrainer.logger.info("Total timesteps = " + totalTimesteps(instanceList));
            PwplACRFTrainer.logger.info("Total WrongWrongs = " + totalAdded);
        }

        /** Runs inference on {@code graph}, appends its qualifying wrong-wrongs to {@code sink}, and returns how many were added. */
        private int collectWrongWrongsForInstance(ACRF.UnrolledGraph graph, List<WrongWrong> sink) {
            Inferencer inferencer = this.acrf.getInferencer();
            inferencer.computeMarginals(graph);
            Assignment observed = graph.getAssignment();
            int added = 0;
            Iterator cliques = graph.unrolledVarSetIterator();
            while (cliques.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliques.next();
                Factor marginal = inferencer.lookupMarginal(clique);
                for (AssignmentIterator assnIt = clique.assignmentIterator(); assnIt.hasNext(); assnIt.advance()) {
                    if (marginal.value(assnIt) > PwplACRFTrainer.this.wrongWrongThreshold) {
                        Assignment candidate = assnIt.assignment();
                        for (int v = 0; v < clique.size(); v++) {
                            Variable var = clique.get(v);
                            if (isWrong2RightAssn(observed, candidate, var)) {
                                sink.add(new WrongWrong(graph, clique, var, candidate.singleIndex()));
                                added++;
                            }
                        }
                    }
                }
            }
            return added;
        }

        /** Sums {@code size()} of every instance's data, which is cast to {@link Sequence}. */
        private int totalTimesteps(InstanceList instanceList) {
            int total = 0;
            int k = 0;
            while (k < instanceList.size()) {
                Sequence seq = (Sequence) instanceList.get(k).getData();
                total += seq.size();
                k++;
            }
            return total;
        }

        /**
         * Returns true when {@code assignment2} differs from {@code assignment} on at least one
         * variable other than {@code variable}, while agreeing with it on {@code variable} itself.
         */
        private boolean isWrong2RightAssn(Assignment assignment, Assignment assignment2, Variable variable) {
            boolean otherWrong = false;
            for (Variable other : assignment2.getVars()) {
                if (other == variable) {
                    continue;
                }
                if (assignment2.get(other) != assignment.get(other)) {
                    otherWrong = true;
                    break;
                }
            }
            return otherWrong && assignment2.get(variable) == assignment.get(variable);
        }

        /**
         * Delegates to the superclass implementation. It is called after the constraints are
         * rebuilt (addWrongWrong), presumably so that the cached value and gradient are
         * recomputed. The bridge/synthetic markers below are decompiler annotations.
         */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient, cc.mallet.grmm.util.CachingOptimizable.Base
        public /* bridge */ /* synthetic */ void forceStale() {
            super.forceStale();
        }

        /** Returns parameter {@code i}; delegates to the superclass implementation. */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient, cc.mallet.grmm.util.CachingOptimizable.Base, cc.mallet.optimize.Optimizable
        public /* bridge */ /* synthetic */ double getParameter(int i) {
            return super.getParameter(i);
        }

        /** Sets parameter {@code i} to {@code d}; delegates to the superclass implementation. */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient, cc.mallet.grmm.util.CachingOptimizable.Base, cc.mallet.optimize.Optimizable
        public /* bridge */ /* synthetic */ void setParameter(int i, double d) {
            super.setParameter(i, d);
        }

        /** Sets all parameters from {@code dArr}; delegates to the superclass implementation. */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient, cc.mallet.grmm.util.CachingOptimizable.Base, cc.mallet.optimize.Optimizable
        public /* bridge */ /* synthetic */ void setParameters(double[] dArr) {
            super.setParameters(dArr);
        }
    }

    /** Creates this trainer's {@link Maxable} objective for {@code acrf} over {@code instanceList}. */
    @Override // cc.mallet.grmm.learning.DefaultAcrfTrainer
    public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList instanceList) {
        Maxable objective = this.new Maxable(acrf, instanceList);
        return objective;
    }

    /** Returns the clique-marginal value an assignment must exceed to be considered for wrong-wrongs. */
    public double getWrongWrongThreshold() {
        final double threshold = wrongWrongThreshold;
        return threshold;
    }

    /** Sets the clique-marginal value an assignment must exceed to be considered for wrong-wrongs. */
    public void setWrongWrongThreshold(double d) {
        wrongWrongThreshold = d;
    }

    /**
     * Selects the wrong-wrong mode.
     *
     * @param i {@link #NO_WRONG_WRONG} or {@link #CONDITION_WW}
     * @throws IllegalArgumentException for any other value. Such values would otherwise only
     *         fail later, with an IllegalStateException during training.
     */
    public void setWrongWrongType(int i) {
        if (i != NO_WRONG_WRONG && i != CONDITION_WW) {
            throw new IllegalArgumentException("Unknown wrong-wrong type " + i + "; expected " + NO_WRONG_WRONG + " or " + CONDITION_WW);
        }
        this.wrongWrongType = i;
    }

    /** Sets the iteration count for the initial training phase that runs before wrong-wrongs are added. */
    public void setWrongWrongIter(int i) {
        wrongWrongIter = i;
    }

    /**
     * Trains {@code acrf}.
     * <p>
     * With {@link #NO_WRONG_WRONG} this just delegates to the superclass. Otherwise training runs
     * in the following steps:
     * <ol>
     * <li>{@code super.train} is called with {@code wrongWrongIter} as the iteration count.</li>
     * <li>The model is written to {@code initial-acrf.ser.gz} under {@code outputPrefix}.</li>
     * <li>Wrong-wrong terms are added to the objective.</li>
     * <li>{@code super.train} runs again with {@code i}.</li>
     * <li>The per-instance training likelihood is logged.</li>
     * </ol>
     *
     * @return the value returned by the last {@code super.train} call
     * @throws IllegalArgumentException if wrong-wrongs are enabled and {@code byGradientValue} is
     *         not a {@link Maxable}; checked before the first phase so it is not run in vain
     */
    @Override // cc.mallet.grmm.learning.DefaultAcrfTrainer, cc.mallet.grmm.learning.ACRFTrainer
    public boolean train(ACRF acrf, InstanceList instanceList, InstanceList instanceList2, InstanceList instanceList3, ACRFEvaluator aCRFEvaluator, int i, Optimizable.ByGradientValue byGradientValue) {
        if (this.wrongWrongType == NO_WRONG_WRONG) {
            return super.train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i, byGradientValue);
        }
        // addWrongWrong is only available on Maxable; fail before the first (expensive) phase.
        if (!(byGradientValue instanceof Maxable)) {
            throw new IllegalArgumentException("Wrong-wrong training requires a PwplACRFTrainer.Maxable optimizable, got " + byGradientValue);
        }
        Maxable maxable = (Maxable) byGradientValue;
        logger.info("BiconditionalPiecewiseACRFTrainer: Initial training");
        super.train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, this.wrongWrongIter, byGradientValue);
        FileUtils.writeGzippedObject(new File(this.outputPrefix, "initial-acrf.ser.gz"), acrf);
        logger.info("BiconditionalPiecewiseACRFTrainer: Adding wrong-wrongs");
        maxable.addWrongWrong(instanceList);
        logger.info("BiconditionalPiecewiseACRFTrainer: Training with wrong-wrongs");
        boolean result = super.train(acrf, instanceList, instanceList2, instanceList3, aCRFEvaluator, i, byGradientValue);
        reportTrainingLikelihood(acrf, instanceList);
        return result;
    }

    /**
     * Runs inference on each instance and logs {@code lookupLogJoint} of its observed assignment.
     * It then logs the sum over all instances.
     */
    public static void reportTrainingLikelihood(ACRF acrf, InstanceList instanceList) {
        double total = 0.0d;
        Inferencer inferencer = acrf.getInferencer();
        int k = 0;
        while (k < instanceList.size()) {
            ACRF.UnrolledGraph graph = acrf.unroll(instanceList.get(k));
            inferencer.computeMarginals(graph);
            double logJoint = inferencer.lookupLogJoint(graph.getAssignment());
            total += logJoint;
            logger.info("...instance " + k + " likelihood = " + logJoint);
            k++;
        }
        logger.info("Unregularized joint likelihood = " + total);
    }
}
