package cc.mallet.classify.constraints.ge;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;
import gnu.trove.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Iterator;

/**
 * Generalized-expectation (GE) constraints for MaxEnt training in which each constrained
 * (feature, label) pair carries a target range {@code [lower, upper]} and a weight.
 *
 * <p>The model expectation of a pair is optionally divided by the feature's total count. It is
 * penalized by {@code weight * (bound - expectation)^2} only when it falls outside its range.
 * {@link #getValue()} returns the negated sum of these penalties.
 *
 * <p>The feature index {@code numFeatures} (one past the last real feature index) is treated as
 * an always-active pseudo-feature: every instance contributes to it with value 1.
 *
 * <p>Expectations must be allocated with {@link #zeroExpectations()} before
 * {@link #computeExpectations} is called. {@link #getCompositeConstraintFeatureValue} reads the
 * caches filled by the most recent {@link #preProcess(FeatureVector)} call.
 * NOTE(review): the exact driving call sequence lives in the GE trainer; confirm there.
 *
 * <p>Not thread-safe: {@link #preProcess(FeatureVector)} overwrites shared per-instance caches.
 */
public class MaxEntRangeL2FLGEConstraints implements MaxEntGEConstraint {
    private final boolean useValues;
    private final boolean normalize;
    private final int numFeatures;
    private final int numLabels;

    /** Per-feature constraints, keyed by feature index (including the {@code numFeatures} pseudo-feature). */
    protected TIntObjectHashMap<MaxEntL2IndGEConstraint> constraints = new TIntObjectHashMap<>();
    /** Constrained feature indices of the most recently preprocessed instance. */
    protected TIntArrayList indexCache = new TIntArrayList();
    /** Feature values parallel to {@link #indexCache}; only maintained when {@code useValues} is set. */
    protected TDoubleArrayList valueCache = new TDoubleArrayList();

    /**
     * Range constraints for a single feature over one or more labels.
     *
     * <p>Per-label data is stored positionally. A label maps through {@link #labelMap} to a
     * position that indexes {@link #lower}, {@link #upper}, {@link #weights} and
     * {@link #expectation}.
     */
    public class MaxEntL2IndGEConstraint {
        /** Accumulated model expectation per position; allocated by the outer {@code zeroExpectations()}. */
        protected double[] expectation;
        protected ArrayList<Double> lower = new ArrayList<>();
        protected ArrayList<Double> upper = new ArrayList<>();
        protected ArrayList<Double> weights = new ArrayList<>();
        /** Label index to position in the per-label lists. */
        protected HashMap<Integer, Integer> labelMap = new HashMap<>();
        /** Number of positions added so far (also the next free position). */
        protected int index = 0;
        /** Weighted occurrence count of this feature, accumulated by the outer {@code preProcess(InstanceList)}. */
        protected double count = 0.0d;

        public MaxEntL2IndGEConstraint() {
        }

        /**
         * Appends a range constraint for {@code label}.
         *
         * <p>If the label was already added, the label map is redirected to the new position
         * and the older position is no longer consulted.
         *
         * @param label label index being constrained
         * @param lowerBound lower end of the target range
         * @param upperBound upper end of the target range
         * @param weight multiplier on the squared violation
         */
        public void add(int label, double lowerBound, double upperBound, double weight) {
            lower.add(lowerBound);
            upper.add(upperBound);
            weights.add(weight);
            labelMap.put(label, index);
            index++;
        }

        /** Returns the storage position for {@code label}, or -1 if the label is unconstrained. */
        private int positionOf(int label) {
            Integer pos = labelMap.get(label);
            return pos == null ? -1 : pos;
        }

        /**
         * Adds {@code amount} to the expectation for {@code label}; unconstrained labels are ignored.
         * Requires {@link #expectation} to have been allocated.
         */
        public void incrementExpectation(int label, double amount) {
            int pos = positionOf(label);
            if (pos >= 0) {
                expectation[pos] += amount;
            }
        }

        /**
         * Returns the (non-negated) squared-violation penalty for {@code label}.
         *
         * @return {@code weight * (bound - estimate)^2} if the estimate lies outside its range,
         *     otherwise 0. Unconstrained labels also yield 0.
         */
        public double getValue(int label) {
            int pos = positionOf(label);
            if (pos < 0) {
                return 0.0d;
            }
            assert count != 0.0d : "constraint feature has zero count";
            double estimate = normalize ? expectation[pos] / count : expectation[pos];
            double lo = lower.get(pos);
            if (estimate < lo) {
                double gap = lo - estimate;
                return weights.get(pos) * (gap * gap);
            }
            double hi = upper.get(pos);
            if (estimate > hi) {
                double gap = hi - estimate;
                return weights.get(pos) * (gap * gap);
            }
            return 0.0d;
        }

        /** Returns the number of positions added via {@link #add}. */
        public int getNumConstrainedLabels() {
            return index;
        }

        /**
         * Returns the derivative of the negated penalty for {@code label} with respect to the raw
         * (unnormalized) expectation. Returns 0 inside the range or for unconstrained labels.
         */
        public double getGradientContribution(int label) {
            int pos = positionOf(label);
            if (pos < 0) {
                return 0.0d;
            }
            assert count != 0.0d : "constraint feature has zero count";
            double raw = expectation[pos];
            double lo = lower.get(pos);
            double hi = upper.get(pos);
            double weight = weights.get(pos);
            if (normalize) {
                // d/dE of -w * (b - E/count)^2  ==  2w * (b/count - E/count^2)
                double estimate = raw / count;
                if (estimate < lo) {
                    return 2.0d * weight * ((lo / count) - (raw / (count * count)));
                }
                if (estimate > hi) {
                    return 2.0d * weight * ((hi / count) - (raw / (count * count)));
                }
                return 0.0d;
            }
            // d/dE of -w * (b - E)^2  ==  2w * (b - E)
            if (raw < lo) {
                return 2.0d * weight * (lo - raw);
            }
            if (raw > hi) {
                return 2.0d * weight * (hi - raw);
            }
            return 0.0d;
        }
    }

    /**
     * @param numFeatures number of real features; this index is used for the always-on pseudo-feature
     * @param numLabels number of labels
     * @param useValues if true, weight contributions by feature values rather than presence
     * @param normalize if true, divide expectations by the feature count before comparing to ranges
     */
    public MaxEntRangeL2FLGEConstraints(int numFeatures, int numLabels, boolean useValues, boolean normalize) {
        this.numFeatures = numFeatures;
        this.numLabels = numLabels;
        this.useValues = useValues;
        this.normalize = normalize;
    }

    /**
     * Adds a range constraint on ({@code featureIndex}, {@code labelIndex}).
     *
     * @param featureIndex feature index, or {@code numFeatures} for the always-on pseudo-feature
     * @param labelIndex label index
     * @param lower lower end of the target range
     * @param upper upper end of the target range
     * @param weight multiplier on the squared violation
     */
    public void addConstraint(int featureIndex, int labelIndex, double lower, double upper, double weight) {
        MaxEntL2IndGEConstraint constraint = constraints.get(featureIndex);
        if (constraint == null) {
            constraint = new MaxEntL2IndGEConstraint();
            constraints.put(featureIndex, constraint);
        }
        constraint.add(labelIndex, lower, upper, weight);
    }

    /**
     * Accumulates per-feature counts over {@code data} and flags instances that touch any constraint.
     *
     * @return a bitset whose bit {@code i} is set iff instance {@code i} has a constrained feature
     *     (or the always-on pseudo-feature is constrained)
     */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public BitSet preProcess(InstanceList data) {
        BitSet constrained = new BitSet(data.size());
        MaxEntL2IndGEConstraint defaultConstraint = constraints.get(numFeatures);
        int ii = 0;
        for (Instance instance : data) {
            double weight = data.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                MaxEntL2IndGEConstraint constraint = constraints.get(fv.indexAtLocation(loc));
                if (constraint != null) {
                    constraint.count += useValues ? weight * fv.valueAtLocation(loc) : weight;
                    constrained.set(ii);
                }
            }
            // Always-on pseudo-feature: flag the CURRENT instance before advancing the index
            // (previously the index was advanced first, flagging the next instance instead).
            if (defaultConstraint != null) {
                constrained.set(ii);
                defaultConstraint.count += weight;
            }
            ii++;
        }
        return constrained;
    }

    /**
     * Caches the constrained feature indices (and values, if {@code useValues}) of one instance.
     * The always-on pseudo-feature is appended last with value 1 when it is constrained.
     */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public void preProcess(FeatureVector input) {
        indexCache.resetQuick();
        if (useValues) {
            valueCache.resetQuick();
        }
        for (int loc = 0; loc < input.numLocations(); loc++) {
            int fi = input.indexAtLocation(loc);
            if (constraints.containsKey(fi)) {
                indexCache.add(fi);
                if (useValues) {
                    valueCache.add(input.valueAtLocation(loc));
                }
            }
        }
        if (constraints.containsKey(numFeatures)) {
            indexCache.add(numFeatures);
            if (useValues) {
                valueCache.add(1.0d);
            }
        }
    }

    /**
     * Sums gradient contributions for {@code label} over the features cached by the last
     * {@link #preProcess(FeatureVector)} call. The {@code input} argument itself is not read.
     */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public double getCompositeConstraintFeatureValue(FeatureVector input, int label) {
        double total = 0.0d;
        for (int i = 0; i < indexCache.size(); i++) {
            double grad = constraints.get(indexCache.getQuick(i)).getGradientContribution(label);
            total += useValues ? grad * valueCache.getQuick(i) : grad;
        }
        return total;
    }

    /**
     * Adds {@code weight * dist[label]} (times the feature value, if {@code useValues}) to the
     * expectation of every constrained feature of {@code input}, for every constrained label.
     * Labels are mapped to storage positions through each constraint's label map.
     */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public void computeExpectations(FeatureVector input, double[] dist, double weight) {
        preProcess(input);
        for (int i = 0; i < indexCache.size(); i++) {
            MaxEntL2IndGEConstraint constraint = constraints.get(indexCache.getQuick(i));
            double value = useValues ? valueCache.getQuick(i) : 1.0d;
            for (int li = 0; li < numLabels; li++) {
                // Must go through incrementExpectation: expectation[] is indexed by position,
                // not by label, and is only as long as the number of constrained labels.
                constraint.incrementExpectation(li, weight * dist[li] * value);
            }
        }
    }

    /**
     * Returns the negated sum of range-violation penalties over all constraints whose feature
     * has a positive count.
     */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public double getValue() {
        double value = 0.0d;
        for (int fi : constraints.keys()) {
            MaxEntL2IndGEConstraint constraint = constraints.get(fi);
            if (constraint.count > 0.0d) {
                for (int li = 0; li < numLabels; li++) {
                    value -= constraint.getValue(li);
                }
            }
        }
        assert !Double.isNaN(value) && !Double.isInfinite(value) : value;
        return value;
    }

    /** Allocates a fresh zeroed expectation array for every constraint. */
    @Override // cc.mallet.classify.constraints.ge.MaxEntGEConstraint
    public void zeroExpectations() {
        for (int fi : constraints.keys()) {
            MaxEntL2IndGEConstraint constraint = constraints.get(fi);
            constraint.expectation = new double[constraint.getNumConstrainedLabels()];
        }
    }
}
