package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/CRFOptimizableByGE.class */
public class CRFOptimizableByGE implements Optimizable.ByGradientValue {
    private static final int DEFAULT_GPV = 10;
    private CRF crf;
    private ArrayList<GEConstraint> constraints;
    private InstanceList data;
    private int numThreads;
    private double gpv;
    private double weight;
    private int cache;
    private double cachedValue;
    private CRF.Factors cachedGradient;
    private int[][] reverseTrans;
    private int[][] reverseTransIndices;
    private BitSet instancesWithConstraints;
    private ThreadPoolExecutor executor;
    static final /* synthetic */ boolean $assertionsDisabled;

    public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> arrayList, InstanceList instanceList, StateLabelMap stateLabelMap, int i) {
        this(crf, arrayList, instanceList, stateLabelMap, i, 1.0d);
    }

    public CRFOptimizableByGE(CRF crf, ArrayList<GEConstraint> arrayList, InstanceList instanceList, StateLabelMap stateLabelMap, int i, double d) {
        this.crf = crf;
        this.constraints = arrayList;
        this.cache = Integer.MAX_VALUE;
        this.cachedValue = Double.NaN;
        this.cachedGradient = new CRF.Factors(crf);
        this.data = instanceList;
        this.numThreads = i;
        this.weight = d;
        this.instancesWithConstraints = new BitSet(instanceList.size());
        Iterator<GEConstraint> it = arrayList.iterator();
        while (it.hasNext()) {
            GEConstraint next = it.next();
            next.setStateLabelMap(stateLabelMap);
            this.instancesWithConstraints.or(next.preProcess(instanceList));
        }
        this.gpv = 10.0d;
        if (i > 1) {
            this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(i);
        }
        createReverseTransitionMatrices(crf);
    }

    /* JADX WARN: Type inference failed for: r1v4, types: [int[], int[][]] */
    /* JADX WARN: Type inference failed for: r1v7, types: [int[], int[][]] */
    public void createReverseTransitionMatrices(CRF crf) {
        int[] iArr = new int[crf.numStates()];
        for (int i = 0; i < crf.numStates(); i++) {
            CRF.State state = (CRF.State) crf.getState(i);
            for (int i2 = 0; i2 < state.numDestinations(); i2++) {
                int index = state.getDestinationState(i2).getIndex();
                iArr[index] = iArr[index] + 1;
            }
        }
        this.reverseTrans = new int[crf.numStates()];
        this.reverseTransIndices = new int[crf.numStates()];
        for (int i3 = 0; i3 < iArr.length; i3++) {
            this.reverseTrans[i3] = new int[iArr[i3]];
            this.reverseTransIndices[i3] = new int[iArr[i3]];
        }
        int[] iArr2 = new int[crf.numStates()];
        for (int i4 = 0; i4 < crf.numStates(); i4++) {
            CRF.State state2 = (CRF.State) crf.getState(i4);
            for (int i5 = 0; i5 < state2.numDestinations(); i5++) {
                int index2 = state2.getDestinationState(i5).getIndex();
                this.reverseTrans[index2][iArr2[index2]] = i4;
                this.reverseTransIndices[index2][iArr2[index2]] = i5;
                iArr2[index2] = iArr2[index2] + 1;
            }
        }
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.crf.getNumParameters();
    }

    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] dArr) {
        this.crf.getParameters().getParameters(dArr);
    }

    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.crf.getParameters().getParameter(i);
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] dArr) {
        this.crf.getParameters().setParameters(dArr);
        this.crf.weightsValueChanged();
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.crf.getParameters().setParameter(i, d);
        this.crf.weightsValueChanged();
    }

    public void cacheValueAndGradient() {
        ArrayList<SumLattice> arrayList = new ArrayList<>();
        if (this.numThreads == 1) {
            for (int i = 0; i < this.data.size(); i++) {
                if (this.instancesWithConstraints.get(i)) {
                    arrayList.add(new SumLatticeDefault((Transducer) this.crf, (Sequence) this.data.get(i).getData(), (Sequence) null, (Transducer.Incrementor) null, true));
                } else {
                    arrayList.add(null);
                }
            }
        } else {
            ArrayList arrayList2 = new ArrayList();
            if (this.data.size() < this.numThreads) {
                this.numThreads = this.data.size();
            }
            int size = this.data.size() / this.numThreads;
            int i2 = 0;
            int i3 = size;
            int i4 = 0;
            while (i4 < this.numThreads) {
                arrayList2.add(new SumLatticeTask(this.crf, this.data, this.instancesWithConstraints, i2, i3));
                i2 += size;
                i3 = i4 == this.numThreads - 2 ? this.data.size() : i3 + size;
                i4++;
            }
            try {
                this.executor.invokeAll(arrayList2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            Iterator it = arrayList2.iterator();
            while (it.hasNext()) {
                arrayList.addAll(((SumLatticeTask) ((Callable) it.next())).getLattices());
            }
            if (!$assertionsDisabled && arrayList.size() != this.data.size()) {
                throw new AssertionError(arrayList.size() + " " + this.data.size());
            }
        }
        System.err.println("Done computing lattices.");
        Iterator<GEConstraint> it2 = this.constraints.iterator();
        while (it2.hasNext()) {
            GEConstraint next = it2.next();
            next.zeroExpectations();
            next.computeExpectations(arrayList);
        }
        System.err.println("Done computing expectations.");
        this.cachedValue = 0.0d;
        Iterator<GEConstraint> it3 = this.constraints.iterator();
        while (it3.hasNext()) {
            this.cachedValue += it3.next().getValue();
        }
        this.cachedGradient.zero();
        if (this.numThreads == 1) {
            for (int i5 = 0; i5 < this.data.size(); i5++) {
                if (this.instancesWithConstraints.get(i5)) {
                    SumLattice sumLattice = arrayList.get(i5);
                    new GELattice((FeatureVectorSequence) this.data.get(i5).getData(), sumLattice.getGammas(), sumLattice.getXis(), this.crf, this.reverseTrans, this.reverseTransIndices, this.cachedGradient, this.constraints, false);
                }
            }
        } else {
            ArrayList arrayList3 = new ArrayList();
            if (this.data.size() < this.numThreads) {
                this.numThreads = this.data.size();
            }
            int size2 = this.data.size() / this.numThreads;
            int i6 = 0;
            int i7 = size2;
            int i8 = 0;
            while (i8 < this.numThreads) {
                ArrayList arrayList4 = new ArrayList();
                Iterator<GEConstraint> it4 = this.constraints.iterator();
                while (it4.hasNext()) {
                    arrayList4.add(it4.next().copy());
                }
                arrayList3.add(new GELatticeTask(this.crf, this.data, arrayList, arrayList4, this.instancesWithConstraints, this.reverseTrans, this.reverseTransIndices, i6, i7));
                i6 += size2;
                i7 = i8 == this.numThreads - 2 ? this.data.size() : i7 + size2;
                i8++;
            }
            try {
                this.executor.invokeAll(arrayList3);
            } catch (InterruptedException e2) {
                e2.printStackTrace();
            }
            Iterator it5 = arrayList3.iterator();
            while (it5.hasNext()) {
                this.cachedGradient.plusEquals(((GELatticeTask) ((Callable) it5.next())).getGradient(), 1.0d);
            }
        }
        System.err.println("Done computing gradient.");
        this.cachedValue += this.crf.getParameters().gaussianPrior(this.gpv);
        this.cachedGradient.plusEqualsGaussianPriorGradient(this.crf.getParameters(), this.gpv);
        System.err.println("Done computing regularization.");
        if (this.weight != 1.0d) {
            this.cachedValue *= this.weight;
        }
        System.err.println("GE Value = " + this.cachedValue);
    }

    public void setGaussianPriorVariance(double d) {
        this.gpv = d;
    }

    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] dArr) {
        if (this.crf.getWeightsValueChangeStamp() != this.cache) {
            cacheValueAndGradient();
            this.cache = this.crf.getWeightsValueChangeStamp();
        }
        this.cachedGradient.getParameters(dArr);
        if (this.weight != 1.0d) {
            MatrixOps.timesEquals(dArr, this.weight);
        }
    }

    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        if (this.crf.getWeightsValueChangeStamp() != this.cache) {
            cacheValueAndGradient();
            this.cache = this.crf.getWeightsValueChangeStamp();
        }
        return this.cachedValue;
    }

    public void shutdown() {
        if (this.executor == null) {
            return;
        }
        this.executor.shutdown();
        try {
            this.executor.awaitTermination(30L, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        if (!$assertionsDisabled && this.executor.shutdownNow().size() != 0) {
            throw new AssertionError("All tasks didn't finish");
        }
    }

    static {
        $assertionsDisabled = !CRFOptimizableByGE.class.desiredAssertionStatus();
    }
}
