package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import net.jkernelmachines.density.SMODensity;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.DebugPrinter;

/* loaded from: input_file:net/jkernelmachines/classifier/SMOSVM.class */
/**
 * Support vector machine trained with Platt's Sequential Minimal Optimization (SMO).
 *
 * <p>The decision function is {@code f(x) = sum_i alpha_i * y_i * k(x_i, x) - b}. When the
 * training set contains only positive (label {@code 1}) or only negative (label {@code -1})
 * samples, training falls back to {@link SMODensity} and its coefficients are used as the
 * multipliers (negated in the negative-only case).
 *
 * <p>No synchronization is performed: training mutates the multipliers and caches in place.
 *
 * @param <T> type of the input samples handled by the kernel
 */
public class SMOSVM<T> implements KernelSVM<T>, Serializable, Cloneable {
    private static final long serialVersionUID = -1224235635423748229L;

    /** Maximum number of passes over the training set before SMO gives up. */
    private static final int MAX_ITERATIONS = 1000000;

    /** {@code alpha[i] * label[i]} for every training sample; the weights used by {@link #valueOf}. */
    private double[] alphay;
    /** One multiplier per training sample (kept in {@code [0, C]} by the SVM path). */
    private double[] alpha;
    /** Training samples; index {@code i} corresponds to {@code alpha[i]}. */
    private ArrayList<TrainingSample<T>> ts;
    /** Number of training samples currently in use. */
    private int size;
    private Kernel<T> kernel;
    /** Kernel matrix over the training set; only non-null while the optimisation runs. */
    private double[][] kcache;
    /** Error cache: {@code ecache[i] = f(x_i) - y_i} for the current multipliers and bias. */
    private double[] ecache;
    /** Bias, subtracted from the kernel expansion in {@link #valueOf}. */
    private double b;
    private double C = 1.0d;
    /** Numerical tolerance used to decide whether a multiplier sits on a bound. */
    private double eps = 1.0E-15d;
    /** Tolerance on KKT-condition violations. */
    private double tolerance = 1.0E-15d;
    DebugPrinter debug = new DebugPrinter();
    /** Source of the random starting points used by the second-choice heuristic. */
    private Random ran = new Random(System.currentTimeMillis());

    public SMOSVM(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Adds one sample to the training set and retrains, warm-starting from the current
     * multipliers; the new sample starts with a multiplier of 0.
     *
     * @param trainingSample the sample to add
     */
    public void train(TrainingSample<T> trainingSample) {
        if (this.ts == null) {
            this.ts = new ArrayList<>();
            this.b = 0.0d;
        }
        this.ts.add(trainingSample);
        this.size = this.ts.size();
        // On the very first call alpha is still null; otherwise grow it, padding with 0.
        this.alpha = (this.alpha == null)
                ? new double[this.size]
                : Arrays.copyOf(this.alpha, this.size);
        train();
    }

    /**
     * Trains from scratch on a copy of {@code list}: all multipliers and the bias start at 0.
     *
     * @param list the training samples
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        this.ts = new ArrayList<>(list);
        this.alpha = new double[this.ts.size()];
        this.b = 0.0d;
        this.size = this.ts.size();
        train();
    }

    /** Re-runs the optimisation on the current training set, starting from the current multipliers and bias. */
    public void retrain() {
        train();
    }

    /**
     * Runs the optimisation on {@code ts}, starting from the current {@code alpha} and {@code b}.
     *
     * @throws IllegalStateException if no training set has been provided
     */
    private void train() {
        if (this.ts == null) {
            throw new IllegalStateException("SMOSVM: no training set, call train(List) or setTrain first");
        }
        long startTime = System.currentTimeMillis();
        this.size = this.ts.size();
        // Make alpha match the training set: missing entries start at 0, extra ones are dropped.
        if (this.alpha == null) {
            this.alpha = new double[this.size];
        } else if (this.alpha.length != this.size) {
            this.alpha = Arrays.copyOf(this.alpha, this.size);
        }
        syncAlphay();

        int nbPositive = 0;
        int nbNegative = 0;
        for (TrainingSample<T> sample : this.ts) {
            if (sample.label == 1) {
                nbPositive++;
            } else if (sample.label == -1) {
                nbNegative++;
            }
        }
        // A single class cannot be separated: fall back to density estimation.
        if (nbPositive > 0 && nbNegative == 0) {
            this.debug.println(1, "exemple positifs uniquement, SMODensity choisi");
            this.alpha = trainDensity();
            this.alphay = this.alpha.clone();
            return;
        }
        if (nbPositive == 0 && nbNegative > 0) {
            this.debug.println(1, "exemple négatifs uniquement, SMODensity choisi");
            this.alpha = trainDensity();
            this.alphay = new double[this.alpha.length];
            for (int k = 0; k < this.alpha.length; k++) {
                this.alphay[k] = -this.alpha[k];
            }
            return;
        }

        this.debug.println(3, "building cache.");
        this.kcache = this.kernel.getKernelMatrix(this.ts);
        try {
            this.debug.println(4, "kcache size : " + this.kcache.length);
            this.debug.println(3, "kcache built.");
            initErrorCache();
            this.debug.println(4, "smotrain : ecache=" + Arrays.toString(this.ecache));
            long cacheTime = System.currentTimeMillis();

            int iterations = optimiseAll();

            if (DebugPrinter.DEBUG_LEVEL >= 4) {
                this.debug.println(3, "smotrain : after train ecache=" + Arrays.toString(this.ecache));
                double errSum = 0.0d;
                double alphaYSum = 0.0d;
                double alphaSum = 0.0d;
                for (int k = 0; k < this.size; k++) {
                    double label = this.ts.get(k).label;
                    errSum += label * this.ecache[k];
                    alphaYSum += label * this.alpha[k];
                    alphaSum += this.alpha[k];
                }
                this.debug.println(4, "smotrain : after training errSum=" + (errSum / this.size)
                        + " alpySum=" + alphaYSum + " alpSum=" + alphaSum);
            }
            syncAlphay();
            this.debug.println(3, "training done in " + iterations + " iterations timeCache="
                    + (cacheTime - startTime) + " timeTrain=" + (System.currentTimeMillis() - cacheTime));
        } finally {
            // The kernel matrix is only needed during optimisation; release it even on failure.
            this.kcache = null;
        }
    }

    /** Trains an {@link SMODensity} on the bare samples of {@code ts} and returns its coefficients. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private double[] trainDensity() {
        SMODensity density = new SMODensity(this.kernel);
        List samples = new ArrayList(this.ts.size());
        for (TrainingSample<T> sample : this.ts) {
            samples.add(sample.sample);
        }
        density.train(samples);
        return density.getAlphas();
    }

    /** Fills {@code ecache} with {@code f(x_i) - y_i} using {@code alphay}, {@code kcache} and {@code b}. */
    private void initErrorCache() {
        this.ecache = new double[this.size];
        for (int i = 0; i < this.size; i++) {
            double sum = 0.0d;
            for (int j = 0; j < this.size; j++) {
                if (this.alphay[j] != 0.0d) {
                    sum += this.alphay[j] * this.kcache[i][j];
                }
            }
            this.ecache[i] = (sum - this.b) - this.ts.get(i).label;
        }
    }

    /**
     * SMO outer loop: alternates full passes with passes over non-bound multipliers until a
     * full pass changes nothing, or {@link #MAX_ITERATIONS} passes have been made.
     *
     * @return the number of passes performed
     */
    private int optimiseAll() {
        int numChanged = 0;
        boolean examineAll = true;
        int iterations = 0;
        while (numChanged > 0 || examineAll) {
            numChanged = 0;
            for (int i = 0; i < this.size; i++) {
                if ((examineAll || isUnbound(i)) && examiner(i)) {
                    numChanged++;
                }
            }
            if (examineAll) {
                examineAll = false;
            } else if (numChanged == 0) {
                examineAll = true;
            }
            iterations++;
            if (iterations > MAX_ITERATIONS) {
                this.debug.println(1, "Too many iterations...");
                break;
            }
            if (iterations % 10000 == 0) {
                this.debug.println(1, "iteration : " + iterations);
            }
        }
        return iterations;
    }

    /** Returns true when {@code alpha[i]} lies strictly inside {@code (eps, C - eps)}. */
    private boolean isUnbound(int i) {
        return this.alpha[i] > this.eps && this.alpha[i] < this.C - this.eps;
    }

    /**
     * Checks whether sample {@code i2} violates the KKT conditions and, if so, tries to find a
     * partner {@code i1} for a joint step: first the non-bound sample maximising
     * {@code |E2 - E1|}, then every non-bound sample, then every sample (both scans start at
     * a random offset).
     *
     * @param i2 index of the sample to examine
     * @return true if a step was taken
     */
    private boolean examiner(int i2) {
        double y2 = this.ts.get(i2).label;
        double a2 = this.alpha[i2];
        double e2 = this.ecache[i2];
        double r2 = y2 * e2;
        if ((r2 >= -this.tolerance || a2 >= this.C - this.eps) && (r2 <= this.tolerance || a2 <= this.eps)) {
            return false;
        }

        int best = -1;
        double bestGap = 0.0d;
        for (int k = 0; k < this.size; k++) {
            if (isUnbound(k)) {
                double gap = Math.abs(e2 - this.ecache[k]);
                if (gap > bestGap) {
                    bestGap = gap;
                    best = k;
                }
            }
        }
        if (best >= 0 && optimiser(best, i2)) {
            return true;
        }

        int start = this.ran.nextInt(this.size);
        for (int k = start; k < start + this.size; k++) {
            int i1 = k % this.size;
            if (isUnbound(i1) && optimiser(i1, i2)) {
                return true;
            }
        }
        start = this.ran.nextInt(this.size);
        for (int k = start; k < start + this.size; k++) {
            if (optimiser(k % this.size, i2)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Jointly optimises {@code alpha[i1]} and {@code alpha[i2]}, then updates {@code b} and
     * {@code ecache} accordingly.
     *
     * @return true if the multipliers changed
     */
    private boolean optimiser(int i1, int i2) {
        if (i1 == i2) {
            return false;
        }
        double a1 = this.alpha[i1];
        double y1 = this.ts.get(i1).label;
        double e1 = this.ecache[i1];
        double a2 = this.alpha[i2];
        double y2 = this.ts.get(i2).label;
        double e2 = this.ecache[i2];
        double s = y1 * y2;

        // Feasible segment [lo, hi] for the new alpha2 under the box and equality constraints.
        double lo;
        double hi;
        if (y1 == y2) {
            lo = Math.max(0.0d, a2 + a1 - this.C);
            hi = Math.min(this.C, a2 + a1);
        } else {
            lo = Math.max(0.0d, a2 - a1);
            hi = Math.min(this.C, this.C + a2 - a1);
        }
        if (lo == hi) {
            return false;
        }

        double k11 = this.kcache[i1][i1];
        double k22 = this.kcache[i2][i2];
        double k12 = this.kcache[i1][i2];
        double eta = 2.0d * k12 - k11 - k22;

        double a2new;
        if (eta < 0.0d) {
            // Unconstrained optimum along the segment, then clipped.
            a2new = a2 + y2 * (e2 - e1) / eta;
            if (a2new < lo) {
                a2new = lo;
            } else if (a2new > hi) {
                a2new = hi;
            }
        } else {
            // Degenerate curvature: pick the end point with the larger objective.
            double objLo = objectiveAt(i1, i2, lo);
            double objHi = objectiveAt(i1, i2, hi);
            if (objLo > objHi + this.eps) {
                a2new = lo;
            } else if (objLo < objHi - this.eps) {
                a2new = hi;
            } else {
                a2new = a2;
            }
        }
        if (Math.abs(a2new - a2) < this.eps * (a2new + a2 + this.eps)) {
            return false;
        }

        // Keep y1*a1 + y2*a2 constant. If rounding pushed a1 out of [0, C], move a2 by the
        // compensating amount: a1 changes by -delta  =>  a2 changes by s * delta.
        double a1new = a1 + s * (a2 - a2new);
        if (a1new < 0.0d) {
            a2new += s * a1new;
            a1new = 0.0d;
        } else if (a1new > this.C) {
            a2new += s * (a1new - this.C);
            a1new = this.C;
        }

        if (a2new > this.C - this.eps) {
            a2new = this.C;
            this.debug.println(4, "svm : i1=" + i1 + " i2=" + i2 + " a2nouv = C !!! a1nouv=" + a1new
                    + " a1prec=" + a1 + " a2prec=" + a2 + " eta=" + eta + " k12=" + k12 + " k11=" + k11
                    + " k22=" + k22 + " L=" + lo + " H=" + hi + " y1=" + y1 + " y2=" + y2 + " s=" + s
                    + " E1=" + e1 + " e2=" + e2);
        } else if (a2new <= this.eps) {
            a2new = 0.0d;
        }
        if (a1new < this.eps) {
            a1new = 0.0d;
        } else if (a1new > this.C - this.eps) {
            a1new = this.C;
            this.debug.println(4, "svm : i1=" + i1 + " i2=" + i2 + " a1nouv = C !!! a2nouv=" + a2new
                    + " a1prec=" + a1 + " a2prec=" + a2 + " eta=" + eta + " k12=" + k12 + " k11=" + k11
                    + " k22=" + k22 + " L=" + lo + " H=" + hi + " y1=" + y1 + " y2=" + y2 + " s=" + s
                    + " E1=" + e1 + " e2=" + e2);
        }

        double dy1 = y1 * (a1new - a1);
        double dy2 = y2 * (a2new - a2);
        // Bias increment: the value that zeroes the error of a non-bound multiplier; when
        // both are bound, the midpoint of the two candidate increments.
        double deltaB;
        if (a1new > this.eps && a1new < this.C - this.eps) {
            deltaB = e1 + dy1 * k11 + dy2 * k12;
        } else if (a2new > this.eps && a2new < this.C - this.eps) {
            deltaB = e2 + dy1 * k12 + dy2 * k22;
        } else {
            deltaB = (e1 + e2) / 2.0d + (dy1 * (k11 + k12) + dy2 * (k12 + k22)) / 2.0d;
        }
        this.b += deltaB;
        for (int k = 0; k < this.size; k++) {
            this.ecache[k] += dy1 * this.kcache[i1][k] + dy2 * this.kcache[i2][k] - deltaB;
        }
        this.alpha[i1] = a1new;
        this.alpha[i2] = a2new;
        return true;
    }

    /**
     * Dual objective (up to a constant, larger is better) restricted to the pair
     * {@code (i1, i2)} when {@code alpha[i2]} is set to {@code a2} and {@code alpha[i1]} is
     * adjusted to keep the equality constraint. Uses Platt's closed form, relying on
     * {@code ecache[i] + b + y_i = sum_j alpha_j * y_j * k(x_j, x_i)}.
     */
    private double objectiveAt(int i1, int i2, double a2) {
        double y1 = this.ts.get(i1).label;
        double y2 = this.ts.get(i2).label;
        double s = y1 * y2;
        double k11 = this.kcache[i1][i1];
        double k22 = this.kcache[i2][i2];
        double k12 = this.kcache[i1][i2];
        double a1Old = this.alpha[i1];
        double a2Old = this.alpha[i2];
        double f1 = y1 * (this.ecache[i1] + this.b) - a1Old * k11 - s * a2Old * k12;
        double f2 = y2 * (this.ecache[i2] + this.b) - s * a1Old * k12 - a2Old * k22;
        double a1 = a1Old + s * (a2Old - a2);
        // psi is the (minimised) negative dual restricted to the pair; negate so larger is better.
        double psi = a1 * f1 + a2 * f2 + 0.5d * a1 * a1 * k11 + 0.5d * a2 * a2 * k22 + s * a1 * a2 * k12;
        return -psi;
    }

    /** Recomputes {@code alphay} from {@code alpha} and the labels when their lengths agree. */
    private void syncAlphay() {
        if (this.alpha == null || this.ts == null || this.alpha.length != this.ts.size()) {
            return;
        }
        double[] weights = new double[this.alpha.length];
        for (int k = 0; k < weights.length; k++) {
            weights[k] = this.alpha[k] * this.ts.get(k).label;
        }
        this.alphay = weights;
    }

    /**
     * Evaluates {@code sum_i alphay[i] * k(x_i, t) - b}.
     *
     * @param t the sample to classify
     * @return the decision value
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(T t) {
        double value = 0.0d;
        for (int i = 0; i < this.size; i++) {
            if (this.alphay[i] != 0.0d) {
                value += this.alphay[i] * this.kernel.valueOf(this.ts.get(i).sample, t);
            }
        }
        return value - this.b;
    }

    /** Returns the internal multiplier array (not a copy). */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        return this.alpha;
    }

    public double getB() {
        return this.b;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    /** Returns the internal training list (not a copy). */
    public ArrayList<TrainingSample<T>> getTrainingSet() {
        return this.ts;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Replaces the multipliers. If a training set of the same length is present, the weights
     * used by {@link #valueOf} are refreshed immediately.
     */
    public void setAlphas(double[] dArr) {
        this.alpha = dArr;
        syncAlphay();
    }

    /**
     * Replaces the training set with a copy of {@code arrayList} and updates the sample count.
     * Multipliers of matching length must be supplied (via {@link #setAlphas} or training)
     * before {@link #valueOf} is used.
     */
    public void setTrain(ArrayList<TrainingSample<T>> arrayList) {
        this.ts = new ArrayList<>(arrayList);
        this.size = this.ts.size();
        syncAlphay();
    }

    /**
     * Returns an independent copy: the multiplier arrays, error cache and training list are
     * copied so that training the copy does not alter this instance. Samples and kernel are
     * shared. (Named {@code copy2} by the decompiler; the original method is {@code copy}.)
     */
    @Override // net.jkernelmachines.classifier.Classifier
    /* renamed from: copy */
    @SuppressWarnings("unchecked")
    public SMOSVM<T> copy2() throws CloneNotSupportedException {
        SMOSVM<T> copy = (SMOSVM<T>) super.clone();
        if (this.alpha != null) {
            copy.alpha = this.alpha.clone();
        }
        if (this.alphay != null) {
            copy.alphay = this.alphay.clone();
        }
        if (this.ecache != null) {
            copy.ecache = this.ecache.clone();
        }
        if (this.ts != null) {
            copy.ts = new ArrayList<>(this.ts);
        }
        return copy;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.kernel;
    }
}
