package cc.mallet.types;

import cc.mallet.util.Randoms;
import gnu.trove.TIntHashSet;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntIterator;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import net.didion.jwnl.dictionary.file.DictionaryFile;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:cc/mallet/types/Dirichlet.class */
public class Dirichlet {
    // Optional alphabet naming each dimension (may be null).
    Alphabet dict;
    // The parameters are stored factored: alpha_i = magnitude * partition[i] (see alpha(int)).
    double magnitude;
    double[] partition;
    // Lazily created by initRandom().
    Randoms random;
    /** Negated Euler-Mascheroni constant (-gamma); note the negative sign despite the name. */
    public static final double EULER_MASCHERONI = -0.5772156649015329d;
    /** pi^2 / 6. */
    public static final double PI_SQUARED_OVER_SIX = 1.6449340668482264d;
    /** log(2 pi) / 2; assigned in the static initializer. */
    public static final double HALF_LOG_TWO_PI;
    // Coefficients |B_{2k}| / (2k) of the asymptotic expansion of digamma. They must be
    // floating-point fractions: the previous values (0.0 ... 0.0, 3.0, 26.0) were the
    // integer-truncated quotients, which made the series vanish.
    public static final double DIGAMMA_COEF_1 = 1.0d / 12.0d;
    public static final double DIGAMMA_COEF_2 = 1.0d / 120.0d;
    public static final double DIGAMMA_COEF_3 = 1.0d / 252.0d;
    public static final double DIGAMMA_COEF_4 = 1.0d / 240.0d;
    public static final double DIGAMMA_COEF_5 = 1.0d / 132.0d;
    public static final double DIGAMMA_COEF_6 = 691.0d / 32760.0d;
    public static final double DIGAMMA_COEF_7 = 1.0d / 12.0d;
    public static final double DIGAMMA_COEF_8 = 3617.0d / 8160.0d;
    public static final double DIGAMMA_COEF_9 = 43867.0d / 14364.0d;
    public static final double DIGAMMA_COEF_10 = 174611.0d / 6600.0d;
    /** Arguments below this are shifted upward before the asymptotic series is applied. */
    public static final double DIGAMMA_LARGE = 9.5d;
    /** Arguments below this use the small-argument approximation -gamma - 1/x. */
    public static final double DIGAMMA_SMALL = 1.0E-6d;
    // Explicit assertion-enabled flag (decompiled form of the javac synthetic); set in the static initializer.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Base class for estimators that fit a Dirichlet to a collection of multinomials.
     * All multinomials must have the same size and the same Alphabet (compared by reference).
     */
    public static abstract class Estimator {
        ArrayList<Multinomial> multinomials;

        /** Creates an estimator with no multinomials. */
        public Estimator() {
            this.multinomials = new ArrayList<>();
        }

        /**
         * Creates an estimator over a copy of {@code collection}.
         *
         * @throws IllegalArgumentException if the multinomials differ in size or Alphabet
         */
        public Estimator(Collection<Multinomial> collection) {
            this.multinomials = new ArrayList<>(collection);
            for (int i = 1; i < this.multinomials.size(); i++) {
                checkCompatible(this.multinomials.get(i - 1), this.multinomials.get(i));
            }
        }

        /**
         * Adds one multinomial, enforcing the same invariant as the collection constructor.
         *
         * @throws IllegalArgumentException if it differs in size or Alphabet from those already added
         */
        public void addMultinomial(Multinomial multinomial) {
            if (!this.multinomials.isEmpty()) {
                checkCompatible(this.multinomials.get(this.multinomials.size() - 1), multinomial);
            }
            this.multinomials.add(multinomial);
        }

        /** Throws unless {@code a} and {@code b} have equal size and the identical Alphabet. */
        private static void checkCompatible(Multinomial a, Multinomial b) {
            if (a.size() != b.size() || a.getAlphabet() != b.getAlphabet()) {
                throw new IllegalArgumentException("All multinomials must have same size and Alphabet.");
            }
        }

        /** Fits a Dirichlet to the multinomials collected so far. */
        public abstract Dirichlet estimate();
    }

    /**
     * Placeholder for a method-of-moments Dirichlet estimator; {@link #estimate()} is not
     * implemented and always throws.
     */
    public static class MethodOfMomentsEstimator extends Estimator {
        /**
         * Always throws. The previous body computed a normalized mean before throwing: the
         * result was discarded, the loop skipped the first multinomial, and an empty estimator
         * failed with IndexOutOfBoundsException instead of the intended exception.
         *
         * @throws UnsupportedOperationException always
         */
        @Override // cc.mallet.types.Dirichlet.Estimator
        public Dirichlet estimate() {
            throw new UnsupportedOperationException("Not yet implemented.");
        }
    }

    static {
        $assertionsDisabled = !Dirichlet.class.desiredAssertionStatus();
        HALF_LOG_TWO_PI = Math.log(6.283185307179586d) / 2.0d;
    }

    /**
     * Creates a Dirichlet from an explicit magnitude and base measure.
     *
     * @param d    the magnitude; alpha_i = d * dArr[i]
     * @param dArr the base measure, stored by reference (not copied); presumably sums to 1,
     *             which is not checked
     */
    public Dirichlet(double d, double[] dArr) {
        this.random = null;
        this.partition = dArr;
        this.magnitude = d;
    }

    public Dirichlet(double[] dArr) {
        this.magnitude = 1.0d;
        this.random = null;
        this.magnitude = 0.0d;
        this.partition = new double[dArr.length];
        for (double d : dArr) {
            this.magnitude += d;
        }
        for (int i = 0; i < dArr.length; i++) {
            this.partition[i] = dArr[i] / this.magnitude;
        }
    }

    /**
     * Creates a Dirichlet from alpha parameters, labelling dimensions with {@code alphabet}.
     * A non-null alphabet is frozen (stopGrowth) so its size stays in sync.
     *
     * @throws IllegalArgumentException if a non-null alphabet's size differs from dArr.length
     */
    public Dirichlet(double[] dArr, Alphabet alphabet) {
        this(dArr);
        if (alphabet == null) {
            this.dict = null;
            return;
        }
        if (dArr.length != alphabet.size()) {
            throw new IllegalArgumentException("alphas and dict sizes do not match.");
        }
        this.dict = alphabet;
        alphabet.stopGrowth();
    }

    /** Creates a symmetric Dirichlet over {@code alphabet} with every alpha equal to 1. */
    public Dirichlet(Alphabet alphabet) {
        this(alphabet, 1.0);
    }

    /**
     * Creates a symmetric Dirichlet over {@code alphabet} with every alpha equal to {@code d},
     * and freezes the alphabet.
     */
    public Dirichlet(Alphabet alphabet, double d) {
        this(alphabet.size(), d);
        alphabet.stopGrowth();
        this.dict = alphabet;
    }

    /** Creates a symmetric Dirichlet with {@code i} dimensions and every alpha equal to 1. */
    public Dirichlet(int i) {
        this(i, 1.0);
    }

    /**
     * Creates a symmetric Dirichlet with {@code i} dimensions and every alpha equal to {@code d}.
     * The magnitude is {@code i * d} and each partition entry is {@code 1 / i}.
     *
     * @param i number of dimensions; must be positive
     * @param d the common alpha value
     * @throws IllegalArgumentException if {@code i <= 0} (previously an obscure
     *         ArrayIndexOutOfBounds / NegativeArraySize exception)
     */
    public Dirichlet(int i, double d) {
        if (i <= 0) {
            throw new IllegalArgumentException("dimension must be positive: " + i);
        }
        this.random = null;
        this.magnitude = i * d;
        this.partition = new double[i];
        Arrays.fill(this.partition, 1.0d / i);
    }

    /** Lazily creates the random source used by the sampling methods (not synchronized). */
    private void initRandom() {
        if (this.random != null) {
            return;
        }
        this.random = new Randoms();
    }

    /**
     * Samples a multinomial from this Dirichlet by normalizing independent Gamma(alpha_i, 1)
     * draws. Non-positive draws are replaced by 1e-4 before normalizing.
     *
     * @return a freshly allocated probability vector of length {@code size()}
     */
    public double[] nextDistribution() {
        final int dims = this.partition.length;
        double[] sample = new double[dims];
        initRandom();
        double total = 0.0d;
        for (int k = 0; k < dims; k++) {
            double draw = this.random.nextGamma(this.partition[k] * this.magnitude, 1.0d);
            sample[k] = draw <= 0.0d ? 1.0E-4d : draw;
            total += sample[k];
        }
        for (int k = 0; k < dims; k++) {
            sample[k] /= total;
        }
        return sample;
    }

    /**
     * Formats a magnitude and a distribution as {@code "magnitude:\tp0\tp1\t...\t"} using the
     * default-locale number format with at most 5 fraction digits.
     *
     * @param d    the value printed before the colon
     * @param dArr the values printed after it, each followed by a tab
     * @return the formatted line
     */
    public static String distributionToString(double d, double[] dArr) {
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(5);
        StringBuilder out = new StringBuilder(formatter.format(d)).append(":\t");
        for (double value : dArr) {
            out.append(formatter.format(value)).append('\t');
        }
        return out.toString();
    }

    /**
     * Writes one alpha parameter (magnitude * partition[i]) per line to the named file.
     *
     * @param str path of the file to create or overwrite
     * @throws IOException if the file cannot be opened or writing fails
     */
    public void toFile(String str) throws IOException {
        try (PrintWriter printWriter = new PrintWriter(new BufferedWriter(new FileWriter(str)))) {
            for (double p : this.partition) {
                printWriter.println(this.magnitude * p);
            }
            // PrintWriter swallows IOExceptions; checkError() flushes and reports them so the
            // declared IOException is actually thrown on failure.
            if (printWriter.checkError()) {
                throw new IOException("error writing Dirichlet parameters to " + str);
            }
        }
    }

    /**
     * Draws a multinomial from this Dirichlet, then draws a count vector from it.
     *
     * @param i the expected number of tokens
     * @return per-dimension counts
     */
    public int[] drawObservation(int i) {
        initRandom();
        double[] distribution = nextDistribution();
        return drawObservation(i, distribution);
    }

    /**
     * Draws a count vector with a random total length, sampling each token index from {@code dArr}.
     * For {@code i < 100} the length is Poisson with mean {@code i}. Otherwise it is a rounded
     * Gaussian draw with mean {@code i}; the second argument is presumably the variance — confirm
     * against Randoms.nextGaussian.
     *
     * @param i    the expected number of tokens
     * @param dArr the multinomial from which token indices are drawn
     * @return per-dimension token counts of length {@code size()}
     */
    public int[] drawObservation(int i, double[] dArr) {
        initRandom();
        int[] counts = new int[this.partition.length];
        // Previously nextPoisson() was called without an argument, ignoring the requested mean i.
        int length = i < 100
                ? this.random.nextPoisson(i)
                : (int) Math.round(this.random.nextGaussian(i, i));
        for (int n = 0; n < length; n++) {
            counts[this.random.nextDiscrete(dArr)]++;
        }
        return counts;
    }

    /**
     * Draws {@code i} independent observations, each from its own multinomial sampled from
     * this Dirichlet.
     *
     * @param i  number of observations
     * @param i2 expected length of each observation
     * @return an array whose elements are {@code int[]} count vectors
     */
    public Object[] drawObservations(int i, int i2) {
        Object[] observations = new Object[i];
        for (int n = 0; n < observations.length; n++) {
            observations[n] = drawObservation(i2);
        }
        return observations;
    }

    /**
     * Computes log Gamma(d) from the Weierstrass product
     * {@code -gamma*d - log d + sum_k (d/k - log(1 + d/k))}, truncated after 10^7 - 1 terms.
     * Slow: about ten million logarithm evaluations per call.
     */
    public static double logGammaDefinition(double d) {
        final double negEulerGamma = -0.5772156649015329d;
        double result = negEulerGamma * d - Math.log(d);
        for (int k = 1; k < 10000000; k++) {
            double ratio = d / k;
            result += ratio - Math.log(1.0d + ratio);
        }
        return result;
    }

    /**
     * Returns log Gamma(d + i) - log Gamma(d), computed as {@code sum_{k<i} log(d + k)}.
     * Returns 0 when {@code i <= 0}.
     */
    public static double logGammaDifference(double d, int i) {
        double sum = 0.0d;
        int k = 0;
        while (k < i) {
            sum += Math.log(d + k);
            k++;
        }
        return sum;
    }

    /** Returns log Gamma(d), delegating to the Stirling-series implementation. */
    public static double logGamma(double d) {
        final double value = logGammaStirling(d);
        return value;
    }

    /**
     * Approximates log Gamma(d) with Stirling's series. Arguments below 2 are first shifted up
     * with Gamma(x + 1) = x Gamma(x), and the shift is undone afterwards.
     */
    public static double logGammaStirling(double d) {
        int shifts = 0;
        double x = d;
        for (; x < 2.0d; x += 1.0d) {
            shifts++;
        }
        double result = HALF_LOG_TWO_PI + (x - 0.5d) * Math.log(x);
        result -= x;
        result += 1.0d / (12.0d * x);
        result -= 1.0d / (360.0d * x * x * x);
        result += 1.0d / (1260.0d * x * x * x * x * x);
        // Undo each shift: log Gamma(x - 1) = log Gamma(x) - log(x - 1).
        for (; shifts > 0; shifts--) {
            x -= 1.0d;
            result -= Math.log(x);
        }
        return result;
    }

    /**
     * Approximates log Gamma(d) with the closed form
     * {@code log(2 pi)/2 - log(d)/2 + d * (log(d + 1/(12d - 1/(10d))) - 1)}.
     */
    public static double logGammaNemes(double d) {
        double correction = 1.0d / ((12.0d * d) - (1.0d / (10.0d * d)));
        double body = d * (Math.log(d + correction) - 1.0d);
        return (HALF_LOG_TWO_PI - (Math.log(d) / 2.0d)) + body;
    }

    /**
     * Computes the digamma function psi(d).
     *
     * For {@code d < 1e-6} it returns the small-argument approximation {@code -gamma - 1/d}.
     * Otherwise it shifts the argument to at least 9.5 with psi(x) = psi(x + 1) - 1/x and applies
     * the asymptotic series
     * {@code ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8) - 1/(132x^10)
     * + 691/(32760x^12) - 1/(12x^14)}.
     *
     * The previous implementation used zero-valued coefficients, so the series reduced to
     * {@code ln x - 1/(2x)} (an error of about 1/(12x^2), roughly 1e-3 near x = 10).
     */
    public static double digamma(double d) {
        if (d < 1.0E-6d) {
            return (-0.5772156649015329d) - (1.0d / d);
        }
        double psi = 0.0d;
        double z = d;
        while (z < 9.5d) {
            psi -= 1.0d / z;
            z += 1.0d;
        }
        double invZ = 1.0d / z;
        double invZ2 = invZ * invZ;
        // Horner form of the Bernoulli-number series; coefficients are |B_2k| / (2k).
        double series = invZ2 * (1.0d / 12.0d
                - invZ2 * (1.0d / 120.0d
                - invZ2 * (1.0d / 252.0d
                - invZ2 * (1.0d / 240.0d
                - invZ2 * (1.0d / 132.0d
                - invZ2 * (691.0d / 32760.0d
                - invZ2 / 12.0d))))));
        return psi + (Math.log(z) - 0.5d * invZ - series);
    }

    /**
     * Returns psi(d + i) - psi(d), computed exactly as {@code sum_{k<i} 1/(d + k)}.
     * Returns 0 when {@code i <= 0}.
     */
    public static double digammaDifference(double d, int i) {
        double sum = 0.0d;
        int k = 0;
        while (k < i) {
            double term = 1.0d / (d + k);
            sum += term;
            k++;
        }
        return sum;
    }

    /**
     * Approximates the trigamma function psi1(d). Arguments below 2 are shifted up with
     * psi1(x) = psi1(x + 1) + 1/x^2, then the asymptotic series
     * {@code 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7) - 1/(30x^9)} is evaluated
     * with rounded coefficients.
     */
    public static double trigamma(double d) {
        int shifts = 0;
        double x = d;
        while (x < 2.0d) {
            x += 1.0d;
            shifts++;
        }
        double inv = 1.0d / x;
        double inv2 = inv * inv;
        double result = inv + (0.5d * inv2);
        result += (0.1666667d * inv2) * inv;
        result -= ((0.03333333d * inv2) * inv2) * inv;
        result += (((0.02380952d * inv2) * inv2) * inv2) * inv;
        result -= ((((0.03333333d * inv2) * inv2) * inv2) * inv2) * inv;
        for (; shifts > 0; shifts--) {
            x -= 1.0d;
            result += 1.0d / (x * x);
        }
        return result;
    }

    /**
     * Estimates the concentration (sum of alphas) of a symmetric Dirichlet from histograms,
     * using 200 fixed-point iterations.
     *
     * @param iArr  count histogram: iArr[n] is the number of (observation, dimension) cells with
     *              count n, as built by testSymmetricConcentration
     * @param iArr2 length histogram: iArr2[n] is the number of observations of total length n
     * @param i     number of dimensions
     * @param d     initial concentration
     * @return the estimated concentration
     */
    public static double learnSymmetricConcentration(int[] iArr, int[] iArr2, int i, double d) {
        int[] countHistogram = iArr;
        int[] observationLengths = iArr2;

        // Largest count with a non-zero entry, so the numerator loop skips trailing zeros.
        int largestNonZeroCount = 0;
        for (int n = 0; n < countHistogram.length; n++) {
            if (countHistogram[n] > 0) {
                largestNonZeroCount = n;
            }
        }

        // Dense, increasing list of the observation lengths that occur.
        int[] nonZeroLengths = new int[observationLengths.length];
        int numLengths = 0;
        for (int n = 0; n < observationLengths.length; n++) {
            if (observationLengths[n] > 0) {
                nonZeroLengths[numLengths++] = n;
            }
        }

        double currentValue = d;
        for (int iteration = 1; iteration <= 200; iteration++) {
            double currentParameter = currentValue / i;

            // Numerator: sum_n countHistogram[n] * (psi(a + n) - psi(a)), built incrementally.
            double digammaDelta = 0.0d;
            double numerator = 0.0d;
            for (int n = 1; n <= largestNonZeroCount; n++) {
                digammaDelta += 1.0d / ((currentParameter + n) - 1.0d);
                numerator += countHistogram[n] * digammaDelta;
            }

            // Denominator: sum_n observationLengths[n] * (psi(alpha + n) - psi(alpha)).
            double cachedDigamma = digamma(currentValue);
            digammaDelta = 0.0d;
            double denominator = 0.0d;
            int previousLength = 0;
            for (int k = 0; k < numLengths; k++) {
                int length = nonZeroLengths[k];
                if (length - previousLength > 20) {
                    // Far from the previous length: recompute directly.
                    digammaDelta = digamma(currentValue + length) - cachedDigamma;
                } else {
                    // Close to the previous length: extend the running difference.
                    for (int n = previousLength; n < length; n++) {
                        digammaDelta += 1.0d / (currentValue + n);
                    }
                }
                denominator += digammaDelta * observationLengths[length];
                // Previously never advanced, so short lengths were re-added from 0 on top of
                // the running sum (double counting).
                previousLength = length;
            }

            currentValue = (currentParameter * numerator) / denominator;
        }
        return currentValue;
    }

    /**
     * Experiment driver. Draws observations from a symmetric Dirichlet, fits them with
     * learnParametersWithHistogram and with learnSymmetricConcentration, and prints the
     * results to standard output.
     *
     * @param i  number of dimensions
     * @param i2 number of observations
     * @param i3 expected observation length
     */
    public static void testSymmetricConcentration(int i, int i2, int i3) {
        // NOTE(review): the loop variable is unused, so the same experiment repeats 9 times.
        for (int exponent = -5; exponent < 4; exponent++) {
            double alpha = i * 1.0d;
            Dirichlet prior = new Dirichlet(i, alpha / i);
            int[] countHistogram = new int[1000000];
            int[] observationLengths = new int[1000000];
            Object[] observations = prior.drawObservations(i2, i3);

            Dirichlet optimized = new Dirichlet(i, 1.0d);
            optimized.learnParametersWithHistogram(observations);
            System.out.println(optimized.magnitude);

            for (int obs = 0; obs < i2; obs++) {
                int[] counts = (int[]) observations[obs];
                int length = 0;
                for (int k = 0; k < i; k++) {
                    int count = counts[k];
                    if (count > 0) {
                        length += count;
                        countHistogram[count]++;
                    }
                }
                observationLengths[length]++;
            }

            double estimate = learnSymmetricConcentration(countHistogram, observationLengths, i, 1.0d);
            System.out.println(String.valueOf(alpha) + "\t" + estimate + "\t" + Math.abs(alpha - estimate));
        }
    }

    /**
     * Runs {@link #learnParameters(double[], int[][], int[], double, double, int)} with default
     * settings: d = 1.00001, d2 = 1.0 and 200 iterations.
     */
    public static double learnParameters(double[] dArr, int[][] iArr, int[] iArr2) {
        // NOTE(review): these look like Gamma-prior shape/scale values; confirm.
        final double defaultShape = 1.00001d;
        final double defaultScale = 1.0d;
        final int defaultIterations = 200;
        return learnParameters(dArr, iArr, iArr2, defaultShape, defaultScale, defaultIterations);
    }

    /**
     * Fixed-point estimation of Dirichlet parameters from histograms. Updates {@code dArr}
     * in place for {@code i} iterations.
     *
     * @param dArr  initial parameters (alphas); overwritten with the estimates
     * @param iArr  iArr[k][n]: histogram entry for count n in dimension k (presumably the number
     *              of observations in which dimension k has count n; confirm against callers)
     * @param iArr2 iArr2[n]: histogram entry for observation length n (presumably the number of
     *              observations of total length n)
     * @param d     added to each dimension's numerator (NOTE(review): looks like a prior shape)
     * @param d2    1/d2 is subtracted from the shared denominator (NOTE(review): looks like a prior scale)
     * @param i     number of fixed-point iterations
     * @return the sum of the final parameters
     * @throws RuntimeException if that sum is negative
     */
    public static double learnParameters(double[] dArr, int[][] iArr, int[] iArr2, double d, double d2, int i) {
        // d3: current sum of the parameters.
        double d3 = 0.0d;
        for (double d4 : dArr) {
            d3 += d4;
        }
        // iArr3[k]: largest n with iArr[k][n] > 0, or -1 if there is none. The update loop uses
        // it to skip trailing zero histogram entries.
        int[] iArr3 = new int[iArr.length];
        Arrays.fill(iArr3, -1);
        for (int i2 = 0; i2 < iArr.length; i2++) {
            int[] iArr4 = iArr[i2];
            for (int i3 = 0; i3 < iArr4.length; i3++) {
                if (iArr4[i3] > 0) {
                    iArr3[i2] = i3;
                }
            }
        }
        for (int i4 = 0; i4 < i; i4++) {
            // Shared denominator: after n steps d6 = sum_{m<=n} 1/(d3 + m - 1) = psi(d3 + n) - psi(d3),
            // weighted by the length histogram.
            double d5 = 0.0d;
            double d6 = 0.0d;
            for (int i5 = 1; i5 < iArr2.length; i5++) {
                d6 += 1.0d / ((d3 + i5) - 1.0d);
                d5 += iArr2[i5] * d6;
            }
            // NOTE(review): subtracting 1/d2 can make d7 zero or negative, which yields infinite
            // or negative parameters; only a negative final sum is detected below.
            double d7 = d5 - (1.0d / d2);
            d3 = 0.0d;
            for (int i6 = 0; i6 < dArr.length; i6++) {
                int i7 = iArr3[i6];
                // d8: this dimension's previous value; dArr[i6] is reused as the numerator accumulator.
                double d8 = dArr[i6];
                dArr[i6] = 0.0d;
                double d9 = 0.0d;
                int[] iArr5 = iArr[i6];
                for (int i8 = 1; i8 <= i7; i8++) {
                    d9 += 1.0d / ((d8 + i8) - 1.0d);
                    int i9 = i6;
                    dArr[i9] = dArr[i9] + (iArr5[i8] * d9);
                }
                // Fixed-point update: old value * (numerator + d) / shared denominator.
                dArr[i6] = (d8 * (dArr[i6] + d)) / d7;
                d3 += dArr[i6];
            }
        }
        if (d3 < 0.0d) {
            throw new RuntimeException("sum: " + d3);
        }
        return d3;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v11, types: [int[], int[][]] */
    public long learnParametersWithHistogram(Object[] objArr) {
        int i = 0;
        int[] iArr = new int[this.partition.length];
        Arrays.fill(iArr, 0);
        for (Object obj : objArr) {
            int i2 = 0;
            int[] iArr2 = (int[]) obj;
            for (int i3 = 0; i3 < iArr2.length; i3++) {
                if (iArr2[i3] > iArr[i3]) {
                    iArr[i3] = iArr2[i3];
                }
                i2 += iArr2[i3];
            }
            if (i2 > i) {
                i = i2;
            }
        }
        ?? r0 = new int[this.partition.length];
        for (int i4 = 0; i4 < this.partition.length; i4++) {
            r0[i4] = new int[iArr[i4] + 1];
            Arrays.fill(r0[i4], 0);
        }
        int[] iArr3 = new int[i + 1];
        Arrays.fill(iArr3, 0);
        for (Object obj2 : objArr) {
            int i5 = 0;
            int[] iArr4 = (int[]) obj2;
            for (int i6 = 0; i6 < iArr4.length; i6++) {
                int[] iArr5 = r0[i6];
                int i7 = iArr4[i6];
                iArr5[i7] = iArr5[i7] + 1;
                i5 += iArr4[i6];
            }
            int i8 = i5;
            iArr3[i8] = iArr3[i8] + 1;
        }
        return learnParametersWithHistogram(r0, iArr3);
    }

    /**
     * Fits this Dirichlet from histograms with 1000 fixed-point iterations, starting from the
     * current parameters. Updates {@code partition} and {@code magnitude} in place.
     *
     * @param iArr  iArr[k][n]: number of observations in which dimension k has count n
     *              (as built by learnParametersWithHistogram(Object[]))
     * @param iArr2 iArr2[n]: number of observations of total length n
     * @return elapsed time in milliseconds
     */
    public long learnParametersWithHistogram(int[][] iArr, int[] iArr2) {
        long currentTimeMillis = System.currentTimeMillis();
        // dArr: working alphas; d: their sum.
        double[] dArr = new double[this.partition.length];
        double d = 0.0d;
        for (int i = 0; i < this.partition.length; i++) {
            dArr[i] = this.magnitude * this.partition[i];
            d += dArr[i];
        }
        for (int i2 = 0; i2 < 1000; i2++) {
            // Denominator: d3 accumulates psi(d + n) - psi(d) incrementally.
            double d2 = 0.0d;
            double d3 = 0.0d;
            for (int i3 = 1; i3 < iArr2.length; i3++) {
                d3 += 1.0d / ((d + i3) - 1.0d);
                d2 += iArr2[i3] * d3;
            }
            // Equivalent of `assert d2 > 0` / `assert !isNaN(d2)` (decompiled form).
            if (!$assertionsDisabled && d2 <= 0.0d) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && Double.isNaN(d2)) {
                throw new AssertionError();
            }
            d = 0.0d;
            for (int i4 = 0; i4 < this.partition.length; i4++) {
                // d4: previous alpha; dArr[i4] is reused as the numerator accumulator.
                double d4 = dArr[i4];
                dArr[i4] = 0.0d;
                double d5 = 0.0d;
                int[] iArr3 = iArr[i4];
                if (iArr3.length <= 1) {
                    // Dimension never has a positive count: use a small floor (still rescaled below).
                    dArr[i4] = 1.0E-6d;
                } else {
                    for (int i5 = 1; i5 < iArr3.length; i5++) {
                        d5 += 1.0d / ((d4 + i5) - 1.0d);
                        int i6 = i4;
                        dArr[i6] = dArr[i6] + (iArr3[i5] * d5);
                    }
                }
                // Debug output for a non-positive numerator.
                // NOTE(review): the first line always prints 0 (length of a fresh empty array).
                if (dArr[i4] <= 0.0d) {
                    System.out.println("length of empty array: " + new int[0].length);
                    for (int i7 : iArr3) {
                        System.out.print(String.valueOf(i7) + StringUtils.SPACE);
                    }
                    System.out.println();
                }
                if (!$assertionsDisabled && dArr[i4] <= 0.0d) {
                    throw new AssertionError();
                }
                if (!$assertionsDisabled && Double.isNaN(dArr[i4])) {
                    throw new AssertionError();
                }
                // Fixed-point update: alpha_new = alpha_old * numerator / denominator.
                int i8 = i4;
                dArr[i8] = dArr[i8] * (d4 / d2);
                d += dArr[i4];
            }
        }
        // Store back in factored form (magnitude is reassigned redundantly on every pass).
        for (int i9 = 0; i9 < this.partition.length; i9++) {
            this.partition[i9] = dArr[i9] / d;
            this.magnitude = d;
        }
        return System.currentTimeMillis() - currentTimeMillis;
    }

    /**
     * Transposes raw count vectors into per-dimension rows, computes each observation's total
     * length, and fits this Dirichlet with {@link #learnParametersWithDigamma(int[][], int[])}.
     *
     * @param objArr observations; each element must be an {@code int[]} of per-dimension counts
     * @return elapsed fitting time in milliseconds
     */
    public long learnParametersWithDigamma(Object[] objArr) {
        final int dims = this.partition.length;
        final int numObservations = objArr.length;
        int[][] countsByDimension = new int[dims][numObservations];
        int[] lengths = new int[numObservations];
        int obs = 0;
        for (Object item : objArr) {
            int[] counts = (int[]) item;
            for (int k = 0; k < dims; k++) {
                countsByDimension[k][obs] = counts[k];
                lengths[obs] += counts[k];
            }
            obs++;
        }
        return learnParametersWithDigamma(countsByDimension, lengths);
    }

    /**
     * Fits this Dirichlet with 1000 digamma-based fixed-point iterations, updating
     * {@code partition} and {@code magnitude} after each iteration.
     *
     * @param iArr  iArr[k][j]: count of dimension k in observation j
     *              (as built by learnParametersWithDigamma(Object[]))
     * @param iArr2 iArr2[j]: total length of observation j
     * @return elapsed time in milliseconds
     */
    public long learnParametersWithDigamma(int[][] iArr, int[] iArr2) {
        long currentTimeMillis = System.currentTimeMillis();
        double[] dArr = new double[this.partition.length];
        for (int i = 0; i < 1000; i++) {
            double d = 0.0d;
            // Denominator: sum_j psi(A + n_j) - psi(A), with A = current magnitude.
            double d2 = 0.0d;
            for (int i2 : iArr2) {
                d2 += digamma(this.magnitude + i2);
            }
            double length = d2 - (iArr2.length * digamma(this.magnitude));
            for (int i3 = 0; i3 < this.partition.length; i3++) {
                // Numerator: sum_j psi(alpha_k + n_jk) - psi(alpha_k); zero counts reuse the cached psi(alpha_k).
                dArr[i3] = 0.0d;
                int[] iArr3 = iArr[i3];
                double d3 = this.magnitude * this.partition[i3];
                double digamma = digamma(d3);
                for (int i4 = 0; i4 < iArr3.length; i4++) {
                    if (iArr3[i4] == 0) {
                        int i5 = i3;
                        dArr[i5] = dArr[i5] + digamma;
                    } else {
                        int i6 = i3;
                        dArr[i6] = dArr[i6] + digamma(d3 + iArr3[i4]);
                    }
                }
                int i7 = i3;
                dArr[i7] = dArr[i7] - (iArr3.length * digamma);
                // Non-positive numerator: floor the alpha at 1e-6 instead of rescaling.
                if (dArr[i3] <= 0.0d) {
                    dArr[i3] = 1.0E-6d;
                } else {
                    int i8 = i3;
                    dArr[i8] = dArr[i8] * (d3 / length);
                }
                // Debug output; after the floor above this only fires for a non-positive rescaled value.
                if (dArr[i3] <= 0.0d) {
                    System.out.println(String.valueOf(dArr[i3]) + "\t" + d3 + "\t" + length);
                }
                // Equivalent of `assert` statements (decompiled form).
                if (!$assertionsDisabled && dArr[i3] <= 0.0d) {
                    throw new AssertionError();
                }
                if (!$assertionsDisabled && Double.isNaN(dArr[i3])) {
                    throw new AssertionError();
                }
                d += dArr[i3];
            }
            // Store back in factored form before the next iteration.
            this.magnitude = d;
            for (int i9 = 0; i9 < this.partition.length; i9++) {
                this.partition[i9] = dArr[i9] / this.magnitude;
            }
        }
        return System.currentTimeMillis() - currentTimeMillis;
    }

    /**
     * Fits this Dirichlet by the method of moments. The partition becomes the mean of the
     * per-observation proportions, and the magnitude is
     * {@code exp(sum_k log(m_k (1 - m_k) / v_k - 1) / (K - 1))}, summed over dimensions with a
     * non-zero mean. Updates {@code partition} in place.
     *
     * @param objArr observations; each element must be an {@code int[]} of per-dimension counts
     * @return elapsed time in milliseconds
     * @throws IllegalArgumentException if an observation has a total count of 0
     */
    public long learnParametersWithMoments(Object[] objArr) {
        long start = System.currentTimeMillis();
        int dims = this.partition.length;
        int[] lengths = new int[objArr.length];
        double[] variances = new double[dims];
        Arrays.fill(this.partition, 0.0d);

        for (int n = 0; n < objArr.length; n++) {
            int[] counts = (int[]) objArr[n];
            for (int k = 0; k < dims; k++) {
                lengths[n] += counts[k];
            }
            if (lengths[n] == 0) {
                // Previously an integer division by zero (ArithmeticException).
                throw new IllegalArgumentException("observation " + n + " has no counts");
            }
            for (int k = 0; k < dims; k++) {
                // Floating-point division: int / int truncated nearly every proportion to 0.
                this.partition[k] += (double) counts[k] / lengths[n];
            }
        }
        for (int k = 0; k < dims; k++) {
            this.partition[k] /= objArr.length;
        }

        // Sample variance of each dimension's proportion.
        for (int n = 0; n < objArr.length; n++) {
            int[] counts = (int[]) objArr[n];
            for (int k = 0; k < dims; k++) {
                double diff = (double) counts[k] / lengths[n] - this.partition[k];
                variances[k] += diff * diff;
            }
        }
        for (int k = 0; k < dims; k++) {
            variances[k] /= (objArr.length - 1);
        }

        double logMagnitudeSum = 0.0d;
        for (int k = 0; k < dims; k++) {
            double mean = this.partition[k];
            if (mean != 0.0d) {
                logMagnitudeSum += Math.log(mean * (1.0d - mean) / variances[k] - 1.0d);
            }
        }
        this.magnitude = Math.exp(logMagnitudeSum / (dims - 1));
        return System.currentTimeMillis() - start;
    }

    /**
     * Transposes raw count vectors into per-dimension rows with per-observation totals, then
     * fits this Dirichlet with {@link #learnParametersWithLeaveOneOut(int[][], int[])}.
     *
     * @param objArr observations; each element must be an {@code int[]} of per-dimension counts
     * @return elapsed fitting time in milliseconds
     */
    public long learnParametersWithLeaveOneOut(Object[] objArr) {
        final int dims = this.partition.length;
        int[][] byDimension = new int[dims][objArr.length];
        int[] totals = new int[objArr.length];
        for (int obs = 0; obs < objArr.length; obs++) {
            final int[] observation = (int[]) objArr[obs];
            int running = totals[obs];
            for (int k = 0; k < dims; k++) {
                byDimension[k][obs] = observation[k];
                running += observation[k];
            }
            totals[obs] = running;
        }
        return learnParametersWithLeaveOneOut(byDimension, totals);
    }

    /**
     * Fits this Dirichlet with 1000 fixed-point iterations of a leave-one-out style update,
     * replacing {@code partition} and {@code magnitude} after each iteration.
     *
     * @param iArr  iArr[k][j]: count of dimension k in observation j
     *              (as built by learnParametersWithLeaveOneOut(Object[]))
     * @param iArr2 iArr2[j]: total length of observation j
     * @return elapsed time in milliseconds
     */
    public long learnParametersWithLeaveOneOut(int[][] iArr, int[] iArr2) {
        long currentTimeMillis = System.currentTimeMillis();
        double[] dArr = new double[this.partition.length];
        double[] dArr2 = new double[this.partition.length];
        for (int i = 0; i < 1000; i++) {
            // Denominator: sum_j n_j / (n_j - 1 + A).
            // NOTE(review): a zero-length observation with magnitude exactly 1 gives 0/0 = NaN.
            double d = 0.0d;
            Arrays.fill(dArr2, 0.0d);
            for (int i2 = 0; i2 < iArr2.length; i2++) {
                d += iArr2[i2] / ((iArr2[i2] - 1) + this.magnitude);
            }
            // Numerators: sum_j n_jk / (n_jk - 1 + alpha_k); counts below 2 are skipped.
            for (int i3 = 0; i3 < this.partition.length; i3++) {
                int[] iArr3 = iArr[i3];
                for (int i4 = 0; i4 < iArr3.length; i4++) {
                    if (iArr3[i4] >= 2) {
                        int i5 = i3;
                        dArr2[i5] = dArr2[i5] + (iArr3[i4] / ((iArr3[i4] - 1) + (this.magnitude * this.partition[i3])));
                    }
                }
            }
            // New alpha_k = alpha_k * numerator_k / denominator, floored at 1e-6 when the numerator is 0.
            double d2 = 0.0d;
            for (int i6 = 0; i6 < this.partition.length; i6++) {
                if (dArr2[i6] == 0.0d) {
                    dArr[i6] = 1.0E-6d;
                } else {
                    dArr[i6] = ((this.partition[i6] * this.magnitude) * dArr2[i6]) / d;
                }
                d2 += dArr[i6];
            }
            for (int i7 = 0; i7 < this.partition.length; i7++) {
                this.partition[i7] = dArr[i7] / d2;
            }
            this.magnitude = d2;
        }
        return System.currentTimeMillis() - currentTimeMillis;
    }

    /**
     * Returns the L1 distance between the alpha vectors of this Dirichlet and {@code dirichlet}.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public double absoluteDifference(Dirichlet dirichlet) {
        final int dims = this.partition.length;
        if (dirichlet.partition.length != dims) {
            throw new IllegalArgumentException("dirichlets must have the same dimension to be compared");
        }
        double total = 0.0d;
        for (int k = 0; k < dims; k++) {
            double mine = this.partition[k] * this.magnitude;
            double theirs = dirichlet.partition[k] * dirichlet.magnitude;
            total += Math.abs(mine - theirs);
        }
        return total;
    }

    /**
     * Returns the squared Euclidean distance between the alpha vectors of this Dirichlet and
     * {@code dirichlet}.
     *
     * @throws IllegalArgumentException if the dimensions differ
     */
    public double squaredDifference(Dirichlet dirichlet) {
        final int dims = this.partition.length;
        if (dirichlet.partition.length != dims) {
            throw new IllegalArgumentException("dirichlets must have the same dimension to be compared");
        }
        double total = 0.0d;
        for (int k = 0; k < dims; k++) {
            double gap = (this.partition[k] * this.magnitude) - (dirichlet.partition[k] * dirichlet.magnitude);
            total += Math.pow(gap, 2.0d);
        }
        return total;
    }

    /**
     * Micro-benchmark. For offsets 1..99 it times 10^6 calls of {@code digamma(d + offset)}
     * against {@code digammaDifference(d, offset)}, and prints the timings plus both values
     * to standard output.
     */
    public void checkBreakeven(double d) {
        final int reps = 1000000;
        double baseDigamma = digamma(d);
        for (int offset = 1; offset < 100; offset++) {
            long start = System.currentTimeMillis();
            for (int r = 0; r < reps; r++) {
                digamma(d + offset);
            }
            long direct = System.currentTimeMillis() - start;

            start = System.currentTimeMillis();
            for (int r = 0; r < reps; r++) {
                digammaDifference(d, offset);
            }
            long indirect = System.currentTimeMillis() - start;

            System.out.println(String.valueOf(offset) + "\tdirect: " + direct + "\tindirect: " + indirect + " (" + (direct - indirect) + ")");
            System.out.println(DictionaryFile.COMMENT_HEADER + (digamma(d + offset) - baseDigamma) + StringUtils.SPACE + digammaDifference(d, offset));
        }
    }

    /**
     * Draws a random "true" Dirichlet and observations from it, fits four estimators, and
     * returns a tab-separated line containing the settings followed by each estimator's
     * elapsed time and L1 error.
     *
     * @param d  magnitude of the true Dirichlet
     * @param i  number of dimensions
     * @param i2 number of observations
     * @param i3 expected observation length
     */
    public static String compare(double d, int i, int i2, int i3) {
        StringBuilder report = new StringBuilder();
        report.append(d).append('\t').append(i).append('\t').append(i2).append('\t').append(i3).append('\t');

        Dirichlet truth = new Dirichlet(d, new Dirichlet(i, d / i).nextDistribution());
        Object[] observations = truth.drawObservations(i2, i3);

        Dirichlet digammaFit = new Dirichlet(i, d / i);
        long digammaTime = digammaFit.learnParametersWithDigamma(observations);
        report.append(digammaTime).append('\t').append(truth.absoluteDifference(digammaFit)).append('\t');

        Dirichlet histogramFit = new Dirichlet(i, d / i);
        long histogramTime = histogramFit.learnParametersWithHistogram(observations);
        report.append(histogramTime).append('\t').append(truth.absoluteDifference(histogramFit)).append('\t');

        Dirichlet momentsFit = new Dirichlet(i, d / i);
        long momentsTime = momentsFit.learnParametersWithMoments(observations);
        report.append(momentsTime).append('\t').append(truth.absoluteDifference(momentsFit)).append('\t');

        Dirichlet leaveOneOutFit = new Dirichlet(i, d / i);
        long leaveOneOutTime = leaveOneOutFit.learnParametersWithLeaveOneOut(observations);
        report.append(leaveOneOutTime).append('\t').append(truth.absoluteDifference(leaveOneOutFit)).append('\t');

        return report.toString();
    }

    /**
     * Log likelihood ratio under a Dirichlet-multinomial model: the two sparse count vectors
     * pooled versus generated separately. This is the sparse counterpart of the
     * {@code int[]} overload.
     *
     * @param tIntIntHashMap  counts of X, keyed by dimension (missing keys count as 0)
     * @param tIntIntHashMap2 counts of Y, keyed by dimension (missing keys count as 0)
     * @param d               per-dimension alpha
     * @param d2              the total alpha used in the length terms
     */
    public static double dirichletMultinomialLikelihoodRatio(TIntIntHashMap tIntIntHashMap, TIntIntHashMap tIntIntHashMap2, double d, double d2) {
        // log Gamma(d) is loop-invariant; it was previously computed once and discarded, then
        // recomputed for every key.
        double logGammaPrior = logGamma(d);
        double result = 0.0d;
        int totalX = 0;
        int totalY = 0;
        TIntHashSet keys = new TIntHashSet();
        keys.addAll(tIntIntHashMap.keys());
        keys.addAll(tIntIntHashMap2.keys());
        TIntIterator it = keys.iterator();
        while (it.hasNext()) {
            int key = it.next();
            int countX = tIntIntHashMap.containsKey(key) ? tIntIntHashMap.get(key) : 0;
            int countY = tIntIntHashMap2.containsKey(key) ? tIntIntHashMap2.get(key) : 0;
            totalX += countX;
            totalY += countY;
            result += ((logGammaPrior + logGamma((d + countX) + countY)) - logGamma(d + countX)) - logGamma(d + countY);
        }
        return result + (((logGamma(d2 + totalX) + logGamma(d2 + totalY)) - logGamma(d2)) - logGamma((d2 + totalX) + totalY));
    }

    /**
     * Log likelihood ratio under a Dirichlet-multinomial model: the two count vectors pooled
     * versus generated separately.
     *
     * @param iArr  counts of X
     * @param iArr2 counts of Y (same length as iArr)
     * @param d     per-dimension alpha
     * @param d2    the total alpha used in the length terms
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double dirichletMultinomialLikelihoodRatio(int[] iArr, int[] iArr2, double d, double d2) {
        if (iArr.length != iArr2.length) {
            throw new IllegalArgumentException("both arrays must contain the same number of dimensions");
        }
        final double logGammaPrior = logGamma(d);
        double perDimension = 0.0d;
        int totalX = 0;
        int totalY = 0;
        for (int k = 0; k < iArr.length; k++) {
            final int x = iArr[k];
            final int y = iArr2[k];
            totalX += x;
            totalY += y;
            perDimension += ((logGammaPrior + logGamma((d + x) + y)) - logGamma(d + x)) - logGamma(d + y);
        }
        double totals = ((logGamma(d2 + totalX) + logGamma(d2 + totalY)) - logGamma(d2)) - logGamma((d2 + totalX) + totalY);
        return perDimension + totals;
    }

    /**
     * Log likelihood ratio, using this Dirichlet's own alphas, of the two count vectors being
     * pooled versus generated separately.
     *
     * @throws IllegalArgumentException if the arrays and this Dirichlet differ in dimension
     */
    public double dirichletMultinomialLikelihoodRatio(int[] iArr, int[] iArr2) {
        final int dims = this.partition.length;
        if (iArr.length != iArr2.length || iArr.length != dims) {
            throw new IllegalArgumentException("both arrays and the Dirichlet prior must contain the same number of dimensions");
        }
        double perDimension = 0.0d;
        int totalX = 0;
        int totalY = 0;
        for (int k = 0; k < iArr.length; k++) {
            final int x = iArr[k];
            final int y = iArr2[k];
            totalX += x;
            totalY += y;
            final double alpha = this.partition[k] * this.magnitude;
            perDimension += ((logGamma(alpha) + logGamma((alpha + x) + y)) - logGamma(alpha + x)) - logGamma(alpha + y);
        }
        final double m = this.magnitude;
        return perDimension + (((logGamma(m + totalX) + logGamma(m + totalY)) - logGamma(m)) - logGamma((m + totalX) + totalY));
    }

    /**
     * Log likelihood ratio under the Ewens sampling formula (concentration {@code d}) of the
     * two count vectors being pooled versus generated separately. Each term depends only on the
     * count histograms (how many dimensions have each count).
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double ewensLikelihoodRatio(int[] iArr, int[] iArr2, double d) {
        if (iArr.length != iArr2.length) {
            throw new IllegalArgumentException("both arrays must contain the same number of dimensions");
        }
        int totalX = 0;
        int totalY = 0;
        for (int k = 0; k < iArr.length; k++) {
            totalX += iArr[k];
            totalY += iArr2[k];
        }
        int total = totalX + totalY;

        // histogramX[n]: number of dimensions whose count in iArr is n (likewise for Y and X+Y).
        int[] histogramX = new int[total + 1];
        int[] histogramY = new int[total + 1];
        int[] histogramBoth = new int[total + 1];
        for (int k = 0; k < iArr.length; k++) {
            histogramX[iArr[k]]++;
            // Previously incremented histogramX here, leaving histogramY all zeros.
            histogramY[iArr2[k]]++;
            histogramBoth[iArr[k] + iArr2[k]]++;
        }

        double result = 0.0d;
        for (int n = 1; n <= total; n++) {
            if (histogramX[n] != 0 || histogramY[n] != 0 || histogramBoth[n] != 0) {
                result = result + (((histogramBoth[n] - histogramX[n]) - histogramY[n]) * Math.log(d / n))
                        + ((logGamma(histogramX[n] + 1) + logGamma(histogramY[n] + 1)) - logGamma(histogramBoth[n] + 1));
            }
        }
        return result + ((logGamma(total + 1) - logGamma(totalX + 1)) - logGamma(totalY + 1))
                + (((logGamma(d + totalX) + logGamma(d + totalY)) - logGamma(d)) - logGamma((d + totalX) + totalY));
    }

    /**
     * Runs {@link #compare} ten times for each point of a 5x5x5 grid (dimensions, observation
     * count and observation length, each doubling) and writes the result lines to the file
     * "comparison". Best effort: any exception is printed to standard output.
     */
    public static void runComparison() {
        // try-with-resources: the writer was previously leaked when compare() threw.
        try (PrintWriter printWriter = new PrintWriter(new BufferedWriter(new FileWriter("comparison")))) {
            int dimensions = 10;
            for (int a = 0; a < 5; a++) {
                int numObservations = 100;
                for (int b = 0; b < 5; b++) {
                    int meanLength = 100;
                    for (int c = 0; c < 5; c++) {
                        System.out.println(String.valueOf(dimensions) + "\t" + dimensions + "\t" + numObservations + "\t" + meanLength);
                        for (int trial = 0; trial < 10; trial++) {
                            // The magnitude passed to compare equals the number of dimensions.
                            printWriter.println(compare(dimensions, dimensions, numObservations, meanLength));
                        }
                        printWriter.flush();
                        meanLength *= 2;
                    }
                    numObservations *= 2;
                }
                dimensions *= 2;
            }
        } catch (Exception e) {
            e.printStackTrace(System.out);
        }
    }

    /** Runs the symmetric-concentration experiment with 1000 dimensions, 100 observations and mean length 1000. */
    public static void main(String[] strArr) {
        final int dimensions = 1000;
        final int observations = 100;
        final int meanLength = 1000;
        testSymmetricConcentration(dimensions, observations, meanLength);
    }

    /** Returns the alphabet labelling the dimensions, or null if there is none. */
    public Alphabet getAlphabet() {
        return dict;
    }

    /** Returns the number of dimensions. */
    public int size() {
        return partition.length;
    }

    /** Returns the alpha parameter of dimension {@code i}, i.e. magnitude * partition[i]. */
    public double alpha(int i) {
        return partition[i] * magnitude;
    }

    /**
     * Prints one {@code label=alpha} line per dimension to standard output. The label is the
     * alphabet entry when an alphabet is present, otherwise the index.
     */
    public void print() {
        System.out.println("Dirichlet:");
        for (int i = 0; i < this.partition.length; i++) {
            // Parenthesized label: the old unparenthesized ternary printed only the alphabet
            // entry (without "=alpha") whenever dict was non-null.
            String label = this.dict != null ? this.dict.lookupObject(i).toString() : String.valueOf(i);
            System.out.println(label + "=" + (this.magnitude * this.partition[i]));
        }
    }

    /**
     * Samples a probability vector from this Dirichlet by normalizing Gamma(alpha_i) draws
     * taken from {@code randoms}.
     */
    protected double[] randomRawMultinomial(Randoms randoms) {
        final int dims = this.partition.length;
        double total = 0.0d;
        double[] draws = new double[dims];
        for (int k = 0; k < dims; k++) {
            draws[k] = randoms.nextGamma(this.magnitude * this.partition[k]);
            total += draws[k];
        }
        for (int k = 0; k < dims; k++) {
            draws[k] /= total;
        }
        return draws;
    }

    /** Samples a Multinomial from this Dirichlet, sharing this Dirichlet's alphabet. */
    public Multinomial randomMultinomial(Randoms randoms) {
        double[] probabilities = randomRawMultinomial(randoms);
        return new Multinomial(probabilities, this.dict, this.partition.length, false, false);
    }

    /**
     * Samples a base measure from this Dirichlet and returns a new Dirichlet whose alphas are
     * that measure scaled by {@code size() * d}.
     */
    public Dirichlet randomDirichlet(Randoms randoms, double d) {
        double[] alphas = randomRawMultinomial(randoms);
        final double scale = alphas.length * d;
        for (int k = 0; k < alphas.length; k++) {
            alphas[k] *= scale;
        }
        return new Dirichlet(alphas, this.dict);
    }

    /** Samples a multinomial from this Dirichlet, then a feature sequence of length {@code i} from it. */
    public FeatureSequence randomFeatureSequence(Randoms randoms, int i) {
        Multinomial multinomial = randomMultinomial(randoms);
        return multinomial.randomFeatureSequence(randoms, i);
    }

    /** Samples a feature sequence of length {@code i} and converts it to a FeatureVector. */
    public FeatureVector randomFeatureVector(Randoms randoms, int i) {
        FeatureSequence sequence = randomFeatureSequence(randoms, i);
        return new FeatureVector(sequence);
    }

    /**
     * Samples a feature sequence of length {@code i} and returns it as a TokenSequence, using
     * each position's object's {@code toString()} as the token text.
     */
    public TokenSequence randomTokenSequence(Randoms randoms, int i) {
        FeatureSequence features = randomFeatureSequence(randoms, i);
        TokenSequence tokens = new TokenSequence(i);
        int position = 0;
        while (position < i) {
            tokens.add(features.getObjectAtPosition(position).toString());
            position++;
        }
        return tokens;
    }

    /** Samples a probability vector from this Dirichlet (same as randomRawMultinomial). */
    public double[] randomVector(Randoms randoms) {
        final double[] sample = randomRawMultinomial(randoms);
        return sample;
    }
}
