package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.types.RankedFeatureVector;

/* loaded from: input_file:edu/umass/cs/mallet/base/types/PerLabelInfoGain.class */
/**
 * Computes, for every label, the information gain of every feature with
 * respect to the two-way split "instance has this label" versus "instance
 * does not have this label", and wraps each label's scores in an
 * {@link InfoGain}.
 *
 * <p>Only the presence of a feature index in an instance is counted; feature
 * values are never read. Each instance is attributed to the single label
 * returned by {@code getLabeling().getBestIndex()}.
 */
public class PerLabelInfoGain {
    /** Natural logarithm of 2 at float precision; divides natural logs to yield bits. */
    static final float log2 = (float) Math.log(2.0d);

    /** Asserted to be {@code true} by {@link #calcPerLabelInfoGains(InstanceList)}. */
    static boolean binary = true;

    /** When {@code true}, {@link #calcPerLabelInfoGains(InstanceList)} writes diagnostics to {@code System.out}. */
    static boolean print = false;

    /** One {@link InfoGain} per label, indexed by label index. */
    InfoGain[] ig;

    /** Factory producing one ranked feature vector per label from an instance list. */
    public static class Factory implements RankedFeatureVector.PerLabelFactory {
        /**
         * Builds the per-label information-gain rankings for {@code instanceList}.
         *
         * @param instanceList the instances to analyse
         * @return one ranking per label, indexed by label index
         */
        @Override // edu.umass.cs.mallet.base.types.RankedFeatureVector.PerLabelFactory
        public RankedFeatureVector[] newRankedFeatureVectors(InstanceList instanceList) {
            return new PerLabelInfoGain(instanceList).ig;
        }
    }

    /**
     * Computes the per-label information gains of {@code instanceList} and
     * wraps each label's row of scores in an {@link InfoGain} over the list's
     * data alphabet.
     *
     * @param instanceList the instances to analyse
     */
    public PerLabelInfoGain(InstanceList instanceList) {
        double[][] gains = calcPerLabelInfoGains(instanceList);
        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        int numClasses = instanceList.getTargetAlphabet().size();
        this.ig = new InfoGain[numClasses];
        for (int ci = 0; ci < numClasses; ci++) {
            this.ig[ci] = new InfoGain(dataAlphabet, gains[ci]);
        }
    }

    /**
     * Returns the information-gain ranking for one label.
     *
     * @param i the label index
     * @return the {@link InfoGain} stored for that label
     */
    public InfoGain getInfoGain(int i) {
        return this.ig[i];
    }

    /**
     * Returns the number of labels for which rankings were computed.
     *
     * @return the length of the per-label ranking array
     */
    public int getNumClasses() {
        return this.ig.length;
    }

    /**
     * Entropy, in bits, of a two-outcome distribution.
     *
     * <p>The result is narrowed to float precision before being returned,
     * which matches the long-standing numeric behaviour of this method.
     *
     * @param pc  probability of the first outcome
     * @param pnc probability of the second outcome; {@code pc + pnc} is
     *            asserted to be within 1e-4 of 1
     * @return the entropy, or {@code Transducer.ZERO_COST} when either
     *         probability equals it
     */
    private static double entropy(double pc, double pnc) {
        assert Math.abs((pc + pnc) - 1.0d) < 1.0E-4d : "pc=" + pc + " pnc=" + pnc;
        // NOTE(review): Transducer.ZERO_COST is used as the zero value here
        // (presumably 0.0) -- confirm against Transducer. The guard avoids log(0).
        if (pc == Transducer.ZERO_COST || pnc == Transducer.ZERO_COST) {
            return Transducer.ZERO_COST;
        }
        float bits = (float) ((((-pc) * Math.log(pc)) / log2) - ((pnc * Math.log(pnc)) / log2));
        assert bits >= 0.0f : "pc=" + pc + " pnc=" + pnc;
        return bits;
    }

    /**
     * Computes the information gain of every feature for every label.
     *
     * <p>For label {@code c} and feature {@code f} the score is
     * {@code H(c) - (P(f) * H(c | f present) + P(not f) * H(c | f absent))},
     * where each {@code H} is the two-outcome entropy of "label is c" versus
     * "label is not c". A feature that occurs in no instance scores 0. An
     * empty instance list yields an all-zero matrix.
     *
     * @param instanceList the instances to analyse; each instance's data is
     *                     cast to {@link FeatureVector}
     * @return a matrix indexed {@code [labelIndex][featureIndex]}
     */
    public static double[][] calcPerLabelInfoGains(InstanceList instanceList) {
        assert binary;
        int numClasses = instanceList.getTargetAlphabet().size();
        int numFeatures = instanceList.getDataAlphabet().size();
        int numInstances = instanceList.size();

        // Holds class/feature co-occurrence counts first; each cell is later
        // overwritten in place with the corresponding information gain.
        double[][] gains = new double[numClasses][numFeatures];
        int[] featureCounts = new int[numFeatures];
        int[] classCounts = new int[numClasses];

        for (int i = 0; i < numInstances; i++) {
            Instance instance = instanceList.getInstance(i);
            FeatureVector featureVector = (FeatureVector) instance.getData();
            int classIndex = instance.getLabeling().getBestIndex();
            classCounts[classIndex]++;
            for (int loc = 0; loc < featureVector.numLocations(); loc++) {
                int fi = featureVector.indexAtLocation(loc);
                gains[classIndex][fi] += 1.0d;
                featureCounts[fi]++;
                assert featureCounts[fi] <= numInstances
                        : "fi=" + fi + "ni=" + numInstances + " fc=" + featureCounts[fi] + " i=" + i;
            }
        }

        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        if (print) {
            for (int ci = 0; ci < numClasses; ci++) {
                System.out.println(instanceList.getTargetAlphabet().lookupObject(ci).toString() + "=" + ci);
            }
        }

        // With no instances every probability below would be 0/0; all counts
        // are zero, so the zero-filled matrix is already the answer.
        if (numInstances == 0) {
            return gains;
        }

        // Entropy of "is label ci" vs "is not label ci" over the whole list.
        // The casts are essential: int / int would truncate to 0 or 1.
        double[] classEntropies = new double[numClasses];
        for (int ci = 0; ci < numClasses; ci++) {
            double pc = ((double) classCounts[ci]) / numInstances;
            double pnc = ((double) (numInstances - classCounts[ci])) / numInstances;
            classEntropies[ci] = entropy(pc, pnc);
        }

        for (int fi = 0; fi < numFeatures; fi++) {
            int presentCount = featureCounts[fi];
            int absentCount = numInstances - presentCount;
            double pf = ((double) presentCount) / numInstances;
            double pnf = ((double) absentCount) / numInstances;
            assert pf >= Transducer.ZERO_COST;
            assert pnf >= Transducer.ZERO_COST;

            if (print && fi < 10000) {
                System.out.print(dataAlphabet.lookupObject(fi).toString());
                for (int ci = 0; ci < numClasses; ci++) {
                    System.out.print(" " + gains[ci][fi]);
                }
                System.out.println("");
            }

            for (int ci = 0; ci < numClasses; ci++) {
                if (presentCount == 0) {
                    gains[ci][fi] = 0.0d;
                    continue;
                }
                double classWithFeature = gains[ci][fi];
                double classWithoutFeature = classCounts[ci] - classWithFeature;
                double ef = entropy(classWithFeature / presentCount,
                        (presentCount - classWithFeature) / presentCount);
                // If the feature occurs in every instance there is no
                // "absent" partition; its term is weighted by pnf == 0, so
                // use 0 instead of evaluating 0/0.
                double enf = 0.0d;
                if (absentCount != 0) {
                    enf = entropy(classWithoutFeature / absentCount,
                            (absentCount - classWithoutFeature) / absentCount);
                }
                gains[ci][fi] = classEntropies[ci] - ((pf * ef) + (pnf * enf));
                if (print && fi < 10000) {
                    System.out.println("pf=" + pf + " ef=" + ef + " pnf=" + pnf + " enf=" + enf
                            + " e=" + classEntropies[ci] + " cig=" + gains[ci][fi]);
                }
            }
        }

        if (print) {
            // Report at most the first 100 features, never past the alphabet.
            int reportLimit = Math.min(100, numFeatures);
            for (int fi = 0; fi < reportLimit; fi++) {
                String featureName = dataAlphabet.lookupObject(fi).toString();
                for (int ci = 0; ci < numClasses; ci++) {
                    String className = instanceList.getTargetAlphabet().lookupObject(ci).toString();
                    if (gains[ci][fi] > 0.1d) {
                        System.out.println(featureName + ',' + className + '=' + gains[ci][fi]);
                    }
                }
            }
        }
        return gains;
    }
}
