package net.jkernelmachines.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import net.jkernelmachines.density.DoubleKMeans;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.DebugPrinter;
import net.jkernelmachines.util.algebra.MatrixVectorOperations;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/classifier/DoubleLLSVM.class */
/**
 * Locally Linear SVM (LLSVM) classifier on {@code double[]} samples.
 *
 * <p>Samples are soft-assigned to anchor points obtained with k-means ({@link DoubleKMeans}).
 * The resulting anchor weights {@code lik(x)} select a per-anchor linear model, and the decision
 * value is {@code f(x) = lik(x)^T W x + lik(x)^T b}. {@code W} (one row per anchor) and
 * {@code b} (one entry per anchor) are learned by stochastic subgradient descent on the hinge
 * loss. The weight-decay step is applied only every {@code skip} iterations, following the SGD
 * scheme of Ladicky &amp; Torr, "Locally Linear Support Vector Machines" (ICML 2011).
 */
public class DoubleLLSVM implements Classifier<double[]>, Cloneable {
    /** k-means model providing the anchor points; set by {@link #train(List)}. */
    DoubleKMeans km;
    /** Per-anchor linear weights: one row per anchor, one column per input dimension. */
    double[][] W;
    /** Per-anchor bias terms. */
    double[] b;
    /** Number of anchor points (k-means clusters) used by the next call to train. */
    int K = 32;
    /** Number of training epochs (passes over the shuffled data). */
    int E = 10;
    /** Initial value of the SGD iteration counter; larger values damp the early step sizes. */
    long t0 = 100;
    /** Number of SGD iterations between two weight-decay steps. */
    int skip = 10;
    /** Trade-off constant; the SGD regularization constant is {@code 1 / (C * n)}. */
    double C = 1.0d;
    /** Number of nearest anchors that receive a non-zero weight in the likelihood. */
    int nn = 2;
    DebugPrinter debug = new DebugPrinter();

    /**
     * Trains the classifier: fits the k-means anchors, then runs {@code E} epochs of SGD on the
     * hinge loss. Any previously learned model is replaced.
     *
     * @param list the training samples; labels are used as {@code y} in the margin
     *             {@code y * f(x)} (presumably +1 / -1 — TODO confirm against callers)
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("DoubleLLSVM: cannot train on an empty sample list");
        }
        this.nn = Math.min(this.nn, this.K);
        final int n = list.size();

        // Copy samples/labels once: list.get(i) may be O(i) when the caller passes a LinkedList.
        final List<double[]> samples = new ArrayList<>(n);
        final int[] labels = new int[n];
        for (TrainingSample<double[]> ts : list) {
            labels[samples.size()] = ts.label;
            samples.add(ts.sample);
        }

        this.km = new DoubleKMeans(this.K);
        this.km.train(samples);
        this.debug.println(1, "KM trained");

        // Anchor weights depend only on the fixed k-means model, so compute them once.
        final double[][] lik = new double[n][];
        final List<Integer> order = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            lik[i] = likelihood(samples.get(i));
            order.add(i);
        }
        this.debug.println(1, "likelihood computed");

        // One row per anchor actually produced by k-means, so W always matches lik's length.
        this.W = new double[lik[0].length][samples.get(0).length];
        this.b = new double[lik[0].length];

        final double lambda = 1.0d / (this.C * n);
        long t = this.t0;
        // SGD iterations performed since the last weight-decay step.
        int pending = 0;
        for (int epoch = 0; epoch < this.E; epoch++) {
            Collections.shuffle(order);
            for (int i : order) {
                final double[] g = lik[i];
                final double[] x = samples.get(i);
                final int y = labels[i];
                final double margin = y * (VectorOperations.dot(this.b, g)
                        + VectorOperations.dot(g, MatrixVectorOperations.rMul(this.W, x)));
                if (1.0d - margin > 0.0d) {
                    // Hinge loss active: subgradient step on W and b. When the margin is
                    // satisfied the loss subgradient is zero and neither parameter moves.
                    final double step = (1.0d / (lambda * t)) * y;
                    for (int k = 0; k < g.length; k++) {
                        if (g[k] == 0.0d) {
                            continue; // only the nn nearest anchors carry non-zero weight
                        }
                        final double[] row = this.W[k];
                        for (int d = 0; d < x.length; d++) {
                            row[d] += step * (g[k] * x[d]);
                        }
                    }
                    VectorOperations.addi(this.b, this.b, step, g);
                }
                pending++;
                if (pending >= this.skip) {
                    // Amortized weight decay: one shrink accounts for all pending iterations.
                    // Floating-point division: integer division would truncate to 0 for t > skip.
                    scaleW(1.0d - (double) pending / t);
                    pending = 0;
                }
                t++;
            }
            // Apply the decay still owed for this epoch, then reset so it is not applied twice.
            if (pending > 0) {
                scaleW(1.0d - (double) pending / t);
                pending = 0;
            }
            this.debug.println(1, "epoch " + epoch + " finished");
        }
        this.debug.println(2, "W: " + Arrays.deepToString(this.W));
        this.debug.println(2, "b: " + Arrays.toString(this.b));
    }

    /** Multiplies every entry of {@code W} by {@code factor} (the weight-decay step). */
    private void scaleW(double factor) {
        for (double[] row : this.W) {
            for (int d = 0; d < row.length; d++) {
                row[d] *= factor;
            }
        }
    }

    /**
     * Computes the anchor weights of a sample.
     *
     * <p>The {@code nn} anchors with the smallest {@code distanceToMean} values get a weight
     * proportional to {@code 1 / (1 + distance)}; all others get 0. The weights are normalized
     * to sum to 1. Anchors tied with the {@code nn}-th distance are also included.
     *
     * @param x the sample
     * @return one weight per k-means mean
     */
    private double[] likelihood(double[] x) {
        final double[] dist = this.km.distanceToMean(x);
        final int count = dist.length;
        final double[] sorted = Arrays.copyOf(dist, count);
        Arrays.sort(sorted);
        // Clamp to the anchors actually available: nn or K may be changed after training.
        final int neighbours = Math.max(1, Math.min(this.nn, count));
        final double threshold = sorted[neighbours - 1];
        double total = 0.0d;
        for (int k = 0; k < count; k++) {
            dist[k] = dist[k] <= threshold ? 1.0d / (1.0d + dist[k]) : 0.0d;
            total += dist[k];
        }
        // Normalize explicitly, in place, so the result does not depend on whether
        // VectorOperations.mul scales its argument or returns a new array.
        final double inv = 1.0d / total;
        for (int k = 0; k < count; k++) {
            dist[k] *= inv;
        }
        return dist;
    }

    /**
     * Returns the decision value {@code lik(x)^T W x + lik(x)^T b}; training pushes
     * {@code label * value} above 1, so the sign is the predicted class.
     *
     * @param dArr the sample to classify
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(double[] dArr) {
        if (this.km == null || this.W == null || this.b == null) {
            throw new IllegalStateException("DoubleLLSVM has not been trained");
        }
        final double[] g = likelihood(dArr);
        return VectorOperations.dot(g, MatrixVectorOperations.rMul(this.W, dArr))
                + VectorOperations.dot(g, this.b);
    }

    /**
     * Returns a shallow copy of this classifier made with {@link Object#clone()}. The learned
     * arrays and the k-means model are shared with the original.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    /* renamed from: copy */
    public Classifier<double[]> copy2() throws CloneNotSupportedException {
        return (DoubleLLSVM) clone();
    }

    /** @return the number of anchor points used for training */
    public int getK() {
        return this.K;
    }

    /** @param i the number of anchor points; takes effect at the next call to train */
    public void setK(int i) {
        this.K = i;
    }

    /** @return the number of training epochs */
    public int getE() {
        return this.E;
    }

    /** @param i the number of training epochs */
    public void setE(int i) {
        this.E = i;
    }

    /** @return the regularization trade-off constant */
    public double getC() {
        return this.C;
    }

    /** @param d the regularization trade-off constant */
    public void setC(double d) {
        this.C = d;
    }

    /** @return the number of nearest anchors used in the likelihood */
    public int getNn() {
        return this.nn;
    }

    /** @param i the number of nearest anchors used in the likelihood */
    public void setNn(int i) {
        this.nn = i;
    }

    /** @return the learned per-anchor weights (internal array, not a copy; null before training) */
    public double[][] getW() {
        return this.W;
    }

    /** @return the learned per-anchor biases (internal array, not a copy; null before training) */
    public double[] getB() {
        return this.b;
    }
}
