package net.jkernelmachines.density;

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import net.jkernelmachines.util.DebugPrinter;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/density/DoubleKMeans.class */
/**
 * K-means clustering of {@code double[]} samples, exposed through the {@link DensityFunction}
 * interface.
 *
 * <p>{@link #train(List)} starts from a random assignment of samples to {@code K} clusters and then
 * alternates an assignment step (each sample goes to its nearest mean) with a mean-update step, for
 * at most {@value #MAX_ITERATIONS} iterations. Whenever the assignments stop changing, the
 * per-cluster distortions are compared. If the smallest is zero, or the largest exceeds it by more
 * than {@link #shiftRatio}, the mean of the least-distorted cluster is replaced by a slightly
 * perturbed copy of the most-distorted cluster's mean, and iteration resumes.
 *
 * <p>Note that {@link #valueOf(double[])} returns the index of the closest mean (as a
 * {@code double}), not a density value.
 *
 * <p>Instances are not synchronized; training replaces {@link #means} in place.
 */
public class DoubleKMeans implements DensityFunction<double[]> {
    private static final long serialVersionUID = -376280133933635170L;

    /** Upper bound on the number of assignment/update iterations performed by training. */
    private static final int MAX_ITERATIONS = 10000;

    /** Number of clusters. */
    int K;
    /** Cluster means, one row per cluster; {@code null} until {@link #train(List)} completes. */
    double[][] means;
    /** Max/min cluster-distortion ratio above which training moves (shifts) a mean. */
    double shiftRatio = 20.0d;
    DebugPrinter debug = new DebugPrinter();

    /**
     * Creates an untrained k-means model.
     *
     * @param i number of clusters; must be positive (validated when training)
     */
    public DoubleKMeans(int i) {
        this.K = i;
    }

    /**
     * Not supported: k-means needs a set of samples.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // net.jkernelmachines.density.DensityFunction
    public void train(double[] dArr) {
        throw new UnsupportedOperationException("Training on a single sample is not supported");
    }

    /**
     * Clusters the given samples into {@code K} groups and stores the resulting means.
     *
     * @param list training samples; every vector is assumed to have the length of the first one
     * @throws IllegalArgumentException if {@code K} is not positive
     * @throws ArithmeticException if {@code list} holds fewer than {@code K} samples (including an
     *     empty list)
     */
    @Override // net.jkernelmachines.density.DensityFunction
    public void train(List<double[]> list) {
        if (this.K <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive: " + this.K);
        }
        final int size = list.size();
        // Checked before list.get(0) so an empty list reports this rather than IndexOutOfBounds.
        if (this.K > size) {
            throw new ArithmeticException("Too few data points: " + size + " < " + this.K);
        }
        final int length = list.get(0).length;
        final Random random = new Random();
        final double[][] centers = new double[this.K][length];
        final int[] assignment = new int[size];

        // Random initial partition; the initial means are the centroids of those random groups.
        for (int i = 0; i < size; i++) {
            assignment[i] = random.nextInt(this.K);
        }
        recomputeMeans(list, assignment, centers);

        for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
            final boolean changed = assignToNearest(list, centers, assignment);
            recomputeMeans(list, assignment, centers);
            if (changed) {
                continue;
            }

            // Assignments are stable: measure each cluster's distortion (sum of d2p2 distances of
            // its samples to its mean) to decide whether a mean should be moved.
            final double[] distortion = new double[this.K];
            double total = 0.0d;
            for (int i = 0; i < size; i++) {
                final double d = VectorOperations.d2p2(list.get(i), centers[assignment[i]]);
                distortion[assignment[i]] += d;
                total += d;
            }
            this.debug.println(3, "d: " + Arrays.toString(distortion));
            this.debug.println(2, "total dist: " + total);

            double dmin = Double.POSITIVE_INFINITY;
            double dmax = -1.0d;
            int smallest = -1;
            int largest = -1;
            for (int k = 0; k < this.K; k++) {
                if (distortion[k] < dmin) {
                    dmin = distortion[k];
                    smallest = k;
                }
                if (distortion[k] > dmax) {
                    dmax = distortion[k];
                    largest = k;
                }
            }
            this.debug.println(3, "dmin: " + dmin + "\tdmax: " + dmax);

            // Every sample coincides with its mean, so the partition cannot improve. Here
            // smallest == largest == 0, so a shift would only copy a mean onto itself; stop
            // instead of spinning until MAX_ITERATIONS (or failing on zero-length vectors).
            if (dmax == 0.0d) {
                break;
            }
            if (dmin != 0.0d && dmax / dmin <= this.shiftRatio) {
                break;
            }
            shift(centers, smallest, largest, random);
            this.debug.println(2, "shifting done");
        }
        this.means = centers;
    }

    /**
     * Returns the index of the cluster mean closest to {@code dArr} (smallest
     * {@code VectorOperations.d2p2}; the lowest index wins ties).
     *
     * @param dArr sample to classify
     * @return cluster index as a {@code double}, or {@code -1} if no distance compares below
     *     infinity
     * @throws IllegalStateException if the model has not been trained
     */
    @Override // net.jkernelmachines.density.DensityFunction
    public double valueOf(double[] dArr) {
        return nearest(dArr, requireTrained());
    }

    /**
     * Returns the {@code VectorOperations.d2p2} distance (presumably the squared Euclidean
     * distance; TODO confirm) from {@code dArr} to every cluster mean.
     *
     * @param dArr sample to measure
     * @return one distance per cluster, indexed like the means
     * @throws IllegalStateException if the model has not been trained
     */
    public double[] distanceToMean(double[] dArr) {
        final double[][] centers = requireTrained();
        final double[] distances = new double[centers.length];
        for (int k = 0; k < centers.length; k++) {
            distances[k] = VectorOperations.d2p2(dArr, centers[k]);
        }
        return distances;
    }

    /** Returns the trained means, failing clearly instead of with an NPE when untrained. */
    private double[][] requireTrained() {
        if (this.means == null) {
            throw new IllegalStateException("Model has not been trained");
        }
        return this.means;
    }

    /** Index of the row of {@code centers} with the smallest d2p2 distance to {@code x}. */
    private static int nearest(double[] x, double[][] centers) {
        double best = Double.POSITIVE_INFINITY;
        int index = -1;
        for (int k = 0; k < centers.length; k++) {
            final double d = VectorOperations.d2p2(x, centers[k]);
            if (d < best) {
                best = d;
                index = k;
            }
        }
        return index;
    }

    /** Reassigns every sample to its nearest mean; returns whether any assignment changed. */
    private static boolean assignToNearest(
            List<double[]> list, double[][] centers, int[] assignment) {
        boolean changed = false;
        for (int i = 0; i < assignment.length; i++) {
            final int best = nearest(list.get(i), centers);
            if (best != assignment[i]) {
                changed = true;
                assignment[i] = best;
            }
        }
        return changed;
    }

    /**
     * Recomputes every mean from the current assignment. Assigned samples are summed with
     * {@code VectorOperations.addi} and the sum is scaled by 1/count; a cluster with no samples
     * is left as the all-zero vector.
     */
    private static void recomputeMeans(List<double[]> list, int[] assignment, double[][] centers) {
        final double[] counts = new double[centers.length];
        for (double[] center : centers) {
            Arrays.fill(center, 0.0d);
        }
        for (int i = 0; i < assignment.length; i++) {
            final int k = assignment[i];
            VectorOperations.addi(centers[k], centers[k], 1.0d, list.get(i));
            counts[k] += 1.0d;
        }
        for (int k = 0; k < centers.length; k++) {
            if (counts[k] > 0.0d) {
                VectorOperations.muli(centers[k], centers[k], 1.0d / counts[k]);
            }
        }
    }

    /**
     * Replaces mean {@code target} with a copy of mean {@code source}, then nudges the two apart
     * along one random coordinate. The nudge is 1e-6 times {@code VectorOperations.n2} of the
     * source mean (presumably its norm; a zero value leaves the copies identical).
     */
    private static void shift(double[][] centers, int target, int source, Random random) {
        final int length = centers[source].length;
        final int coord = random.nextInt(length);
        centers[target] = Arrays.copyOf(centers[source], length);
        final double step = 1.0E-6d * VectorOperations.n2(centers[source]);
        centers[target][coord] += step;
        centers[source][coord] -= step;
    }
}
