package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.CachingOptimizable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/grmm/learning/PiecewiseACRFTrainer.class */
/**
 * ACRF trainer whose optimizable objective is a {@link Maxable}: every unrolled clique of a
 * trainable template is normalized locally (by its own factor sum) and a Gaussian prior is
 * applied to the template weights.
 */
public class PiecewiseACRFTrainer extends DefaultAcrfTrainer {
    private static final Logger logger = MalletLogger.getLogger(PiecewiseACRFTrainer.class.getName());
    private static final boolean printGradient = false;

    /**
     * Piecewise objective for an {@link ACRF}.
     *
     * <p>For each unrolled clique whose template index is not -1, the value contribution is
     * {@code log f(y) - log sum(f)}, where {@code f} is the clique's factor and {@code y} the
     * training assignment. Expected feature counts are weighted by {@code f(a) / sum(f)}.
     *
     * <p>Parameter layout used by {@link #getParameters} / {@link #setParametersInternal}: the
     * default weights of every template (in template order), followed by every per-assignment
     * weight vector of every template (in template order).
     *
     * <p>Two usage modes are visible here. One is full-batch, through {@link #computeValue} and
     * {@link #computeValueGradient(double[])}, using constraints collected in the constructor. The
     * other is incremental, through {@link #resetValueGradient}, {@link #computeValueAndGradient}
     * and {@link #getCachedGradient}.
     */
    public static class Maxable extends CachingOptimizable.ByGradient implements Serializable {
        private ACRF acrf;
        InstanceList trainData;
        private ACRF.Template[] templates;
        private ACRF.Template[] fixedTmpls;
        private int numParameters;
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0d;
        // Feature counts at the training assignments, indexed [template][assignment number].
        SparseVector[][] constraints;
        // Locally normalized expected feature counts, indexed [template][assignment number].
        SparseVector[][] expectations;
        // Per-template counts for the default weights, indexed by assignment number.
        SparseVector[] defaultConstraints;
        SparseVector[] defaultExpectations;
        protected BitSet infiniteValues = null;
        private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        // Instances accumulated by computeValueAndGradient since the last resetValueGradient().
        int numInBatch = 0;

        public double getGaussianPriorVariance() {
            return this.gaussianPriorVariance;
        }

        public void setGaussianPriorVariance(double d) {
            this.gaussianPriorVariance = d;
        }

        /** Initializes every template's weights from the data and totals their parameter counts. */
        private void initWeights(InstanceList instanceList) {
            for (ACRF.Template template : this.templates) {
                this.numParameters += template.initWeights(instanceList);
            }
        }

        /** Allocates zeroed accumulators shaped like each template's default and per-assignment weights. */
        private void initConstraintsExpectations() {
            int numTemplates = this.templates.length;
            this.defaultConstraints = new SparseVector[numTemplates];
            this.defaultExpectations = new SparseVector[numTemplates];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector defaults = this.templates[t].getDefaultWeights();
                this.defaultConstraints[t] = (SparseVector) defaults.cloneMatrixZeroed();
                this.defaultExpectations[t] = (SparseVector) defaults.cloneMatrixZeroed();
            }
            // Jagged arrays: the number of weight vectors may differ between templates.
            this.constraints = new SparseVector[numTemplates][];
            this.expectations = new SparseVector[numTemplates][];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                this.constraints[t] = new SparseVector[weights.length];
                this.expectations[t] = new SparseVector[weights.length];
                for (int a = 0; a < weights.length; a++) {
                    this.constraints[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                    this.expectations[t][a] = (SparseVector) weights[a].cloneMatrixZeroed();
                }
            }
        }

        /** Zeroes all expectation accumulators (default and per-assignment). */
        void resetExpectations() {
            for (int t = 0; t < this.expectations.length; t++) {
                this.defaultExpectations[t].setAll(0.0d);
                for (int a = 0; a < this.expectations[t].length; a++) {
                    this.expectations[t][a].setAll(0.0d);
                }
            }
        }

        /** Zeroes all constraint accumulators (default and per-assignment). */
        void resetConstraints() {
            for (int t = 0; t < this.constraints.length; t++) {
                this.defaultConstraints[t].setAll(0.0d);
                for (int a = 0; a < this.constraints[t].length; a++) {
                    this.constraints[t][a].setAll(0.0d);
                }
            }
        }

        /**
         * Builds the objective for {@code acrf} over {@code instanceList}. It initializes the
         * template weights, allocates the accumulators and collects constraints from every
         * training instance.
         */
        protected Maxable(ACRF acrf, InstanceList instanceList) {
            PiecewiseACRFTrainer.logger.finest("Initializing OptimizableACRF.");
            this.acrf = acrf;
            this.templates = acrf.getTemplates();
            this.fixedTmpls = acrf.getFixedTemplates();
            this.trainData = instanceList;
            initWeights(this.trainData);
            initConstraintsExpectations();
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            PiecewiseACRFTrainer.logger.info("Number of training instances = " + this.trainData.size());
            PiecewiseACRFTrainer.logger.info("Number of parameters = " + this.numParameters);
            describePrior();
            PiecewiseACRFTrainer.logger.fine("Computing constraints");
            // NOTE(review): collectConstraints is public (overridable) and called from the constructor.
            collectConstraints(this.trainData);
        }

        private void describePrior() {
            PiecewiseACRFTrainer.logger.info("Using gaussian prior with variance " + this.gaussianPriorVariance);
        }

        @Override // cc.mallet.optimize.Optimizable
        public int getNumParameters() {
            return this.numParameters;
        }

        /**
         * Copies all weights into {@code dArr}, using the layout described in the class comment.
         *
         * @throws IllegalArgumentException if {@code dArr.length != getNumParameters()}
         */
        @Override // cc.mallet.optimize.Optimizable
        public void getParameters(double[] dArr) {
            if (dArr.length != this.numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int offset = 0;
            for (ACRF.Template template : this.templates) {
                double[] values = template.getDefaultWeights().getValues();
                System.arraycopy(values, 0, dArr, offset, values.length);
                offset += values.length;
            }
            for (ACRF.Template template : this.templates) {
                for (SparseVector weights : template.getWeights()) {
                    double[] values = weights.getValues();
                    System.arraycopy(values, 0, dArr, offset, values.length);
                    offset += values.length;
                }
            }
        }

        /** Writes {@code dArr} back into the template weights (same layout as getParameters) and marks caches stale. */
        @Override // cc.mallet.grmm.util.CachingOptimizable.Base
        protected void setParametersInternal(double[] dArr) {
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            int offset = 0;
            for (ACRF.Template template : this.templates) {
                double[] values = template.getDefaultWeights().getValues();
                System.arraycopy(dArr, offset, values, 0, values.length);
                offset += values.length;
            }
            for (ACRF.Template template : this.templates) {
                for (SparseVector weights : template.getWeights()) {
                    double[] values = weights.getValues();
                    System.arraycopy(dArr, offset, values, 0, values.length);
                    offset += values.length;
                }
            }
        }

        /** Returns the live (not copied) expectation accumulators for template {@code i}. */
        public SparseVector[] getExpectations(int i) {
            return this.expectations[i];
        }

        /** Returns the live (not copied) constraint accumulators for template {@code i}. */
        public SparseVector[] getConstraints(int i) {
            return this.constraints[i];
        }

        /** Prints all parameters, tab-separated, on one line of standard output. */
        public void printParameters() {
            double[] params = new double[this.numParameters];
            getParameters(params);
            for (double p : params) {
                System.out.print(p + "\t");
            }
            System.out.println();
        }

        /**
         * Full-batch value: the sum of per-instance piecewise values plus the prior. This method
         * recomputes expectations as a side effect.
         */
        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient
        protected double computeValue() {
            long startMillis = System.currentTimeMillis();
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
            }
            resetExpectations();
            double value = 0.0d;
            int numInstances = this.trainData.size();
            for (int i = 0; i < numInstances; i++) {
                value += computeValueForInstance(i);
            }
            value += computePrior();
            PiecewiseACRFTrainer.logger.info("ACRF Inference time (ms) = " + (System.currentTimeMillis() - startMillis));
            PiecewiseACRFTrainer.logger.info("ACRF unroll time (ms) = 0");
            PiecewiseACRFTrainer.logger.info("getValue (loglikelihood) = " + value);
            return value;
        }

        /** Gaussian log-prior {@code -sum(w^2) / (2 * variance)} over finite, non-NaN per-assignment weights. */
        private double computePrior() {
            // NOTE(review): default weights are not included here, although computeValueGradient
            // applies the prior term -w / variance to them. Value and gradient therefore disagree
            // for default weights. This is left unchanged to preserve the reported objective.
            double prior = 0.0d;
            double twoVariance = 2.0d * this.gaussianPriorVariance;
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                for (int a = 0; a < weights.length; a++) {
                    for (int loc = 0; loc < weights[a].numLocations(); loc++) {
                        double w = weights[a].valueAtLocation(loc);
                        if (weightValid(w, t, a)) {
                            prior += ((-w) * w) / twoVariance;
                        }
                    }
                }
            }
            return prior;
        }

        /**
         * Unrolls instance {@code i}, accumulates its expectations and returns its piecewise value.
         * Returns 0 for an instance with no variables. Returns negative infinity (after printing
         * debug info) if the value is NaN.
         */
        private double computeValueForInstance(int i) {
            Instance instance = this.trainData.get(i);
            ACRF.UnrolledGraph graph = new ACRF.UnrolledGraph(instance, this.templates, this.fixedTmpls);
            if (graph.numVariables() == 0) {
                return 0.0d;
            }
            double value = collectExpectationsAndValue(graph, graph.getAssignment());
            if (!Double.isNaN(value)) {
                return value;
            }
            System.out.println("NaN on instance " + i + " : " + instance.getName());
            printDebugInfo(graph);
            PiecewiseACRFTrainer.logger.warning("Value is NaN in ACRF.getValue() Instance " + i + " : returning -infinity... ");
            return Double.NEGATIVE_INFINITY;
        }

        @Override // cc.mallet.grmm.util.CachingOptimizable.ByGradient
        protected void computeValueGradient(double[] dArr) {
            computeValueGradient(dArr, 1.0d);
        }

        /**
         * Fills {@code dArr} with {@code constraints - expectations - d * w / variance}, using the
         * parameter layout described in the class comment.
         *
         * @param dArr output gradient, of length {@link #getNumParameters()}
         * @param d    multiplier applied to the prior term
         */
        private void computeValueGradient(double[] dArr, double d) {
            int index = 0;
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector defaults = this.templates[t].getDefaultWeights();
                // The accumulators were cloned from these weights, so locations line up.
                SparseVector observed = this.defaultConstraints[t];
                SparseVector expected = this.defaultExpectations[t];
                for (int loc = 0; loc < defaults.numLocations(); loc++) {
                    dArr[index++] = (observed.valueAtLocation(loc) - expected.valueAtLocation(loc))
                            - (d * (defaults.valueAtLocation(loc) / this.gaussianPriorVariance));
                }
            }
            for (int t = 0; t < this.templates.length; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                for (int a = 0; a < weights.length; a++) {
                    SparseVector w = weights[a];
                    SparseVector observed = this.constraints[t][a];
                    SparseVector expected = this.expectations[t][a];
                    for (int loc = 0; loc < w.numLocations(); loc++) {
                        double weight = w.valueAtLocation(loc);
                        double grad;
                        if (Double.isInfinite(weight)) {
                            // Map the sparse location to its feature index before the alphabet lookup.
                            PiecewiseACRFTrainer.logger.warning("Infinite weight for node index " + a + " feature "
                                    + this.acrf.getInputAlphabet().lookupObject(w.indexAtLocation(loc)));
                            grad = 0.0d;
                        } else {
                            grad = (observed.valueAtLocation(loc) - (d * (weight / this.gaussianPriorVariance)))
                                    - expected.valueAtLocation(loc);
                        }
                        dArr[index++] = grad;
                    }
                }
            }
        }

        /**
         * For every clique whose template index is not -1, adds the locally normalized
         * expected feature counts to the expectation accumulators. Returns the sum over those
         * cliques of {@code log f(assignment) - log sum(f)}.
         */
        private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolledGraph, Assignment assignment) {
            double value = 0.0d;
            Iterator it = unrolledGraph.unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                int tidx = clique.getTemplate().index;
                if (tidx == -1) {
                    continue;
                }
                Factor factor = unrolledGraph.factorOf((VarSet) clique);
                double logZ = Math.log(factor.sum());
                AssignmentIterator assnIt = clique.assignmentIterator();
                for (int assn = 0; assnIt.hasNext(); assn++) {
                    double prob = Math.exp(factor.logValue(assnIt) - logZ);
                    this.expectations[tidx][assn].plusEqualsSparse(clique.getFv(), prob);
                    // Default weights may be sparse; only count assignments they contain.
                    if (this.defaultExpectations[tidx].location(assn) != -1) {
                        this.defaultExpectations[tidx].incrementValue(assn, prob);
                    }
                    assnIt.advance();
                }
                value += factor.logValue(assignment) - logZ;
            }
            return value;
        }

        /** Adds the constraints of every instance in {@code instanceList} to the accumulators. */
        public void collectConstraints(InstanceList instanceList) {
            for (int i = 0; i < instanceList.size(); i++) {
                PiecewiseACRFTrainer.logger.finest("*** Collecting constraints for instance " + i);
                collectConstraintsForInstance(instanceList, i);
            }
        }

        /** Adds the feature vector of each trainable clique to the constraint for its training assignment. */
        private void collectConstraintsForInstance(InstanceList instanceList, int i) {
            Iterator it = new ACRF.UnrolledGraph(instanceList.get(i), this.templates, null, false).unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                int tidx = clique.getTemplate().index;
                if (tidx == -1) {
                    continue;
                }
                int assn = clique.lookupAssignmentNumber();
                this.constraints[tidx][assn].plusEqualsSparse(clique.getFv());
                if (this.defaultConstraints[tidx].location(assn) != -1) {
                    this.defaultConstraints[tidx].incrementValue(assn, 1.0d);
                }
            }
        }

        /** Writes the current value gradient, one component per line, to the file {@code str}. */
        void dumpGradientToFile(String str) {
            double[] gradient = new double[getNumParameters()];
            getValueGradient(gradient);
            try (PrintStream out = new PrintStream(new FileOutputStream(str))) {
                for (int i = 0; i < this.numParameters; i++) {
                    out.println(gradient[i]);
                }
            } catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            }
        }

        /** Prints the default constraints and default expectations of every template. */
        void dumpDefaults() {
            System.out.println("Default constraints");
            for (int t = 0; t < this.defaultConstraints.length; t++) {
                System.out.println("Template " + t);
                this.defaultConstraints[t].print();
            }
            System.out.println("Default expectations");
            for (int t = 0; t < this.defaultExpectations.length; t++) {
                System.out.println("Template " + t);
                this.defaultExpectations[t].print();
            }
        }

        /** Prints the ACRF, then each clique's assignment, factor value and factor. */
        void printDebugInfo(ACRF.UnrolledGraph unrolledGraph) {
            this.acrf.print(System.err);
            Assignment assignment = unrolledGraph.getAssignment();
            Iterator it = unrolledGraph.unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                System.out.println("Clique " + clique);
                dumpAssnForClique(assignment, clique);
                Factor factor = unrolledGraph.factorOf((VarSet) clique);
                System.out.println("Value = " + factor.value(assignment));
                System.out.println(factor);
            }
        }

        /** Prints each variable of {@code unrolledVarSet} with its assigned object and index. */
        void dumpAssnForClique(Assignment assignment, ACRF.UnrolledVarSet unrolledVarSet) {
            Iterator it = unrolledVarSet.iterator();
            while (it.hasNext()) {
                Variable variable = (Variable) it.next();
                System.out.println(variable + " ==> " + assignment.getObject(variable) + "  (" + assignment.get(variable) + ")");
            }
        }

        /** Returns false, and logs a warning, if weight {@code d} is infinite or NaN. */
        private boolean weightValid(double d, int i, int i2) {
            if (Double.isInfinite(d)) {
                PiecewiseACRFTrainer.logger.warning("Weight is infinite for clique " + i + " assignment " + i2);
                return false;
            }
            if (Double.isNaN(d)) {
                PiecewiseACRFTrainer.logger.warning("Weight is Nan for clique " + i + " assignment " + i2);
                return false;
            }
            return true;
        }

        /**
         * Incremental mode: accumulates constraints and expectations for instance {@code i}. It
         * returns that instance's value plus 1/N of the prior, where N is the number of training
         * instances.
         */
        public double computeValueAndGradient(int i) {
            this.numInBatch++;
            collectConstraintsForInstance(this.trainData, i);
            return computeValueForInstance(i) + (computePrior() / this.trainData.size());
        }

        public int getNumInstances() {
            return this.trainData.size();
        }

        /**
         * Writes the gradient of the accumulated batch into {@code dArr}. The prior is scaled by
         * {@code numInBatch / N}, which matches the per-instance prior share added by
         * {@link #computeValueAndGradient}.
         */
        public void getCachedGradient(double[] dArr) {
            int numInstances = this.trainData.size();
            // Floating-point division: the batch usually holds fewer than N instances.
            double priorScale = numInstances == 0 ? 0.0d : ((double) this.numInBatch) / numInstances;
            computeValueGradient(dArr, priorScale);
        }

        /** Clears the accumulated expectations, constraints and batch count before a new batch. */
        public void resetValueGradient() {
            resetExpectations();
            resetConstraints();
            this.numInBatch = 0;
        }
    }

    @Override // cc.mallet.grmm.learning.DefaultAcrfTrainer
    public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList instanceList) {
        return new Maxable(acrf, instanceList);
    }
}
