package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.threading.ThreadedMatrixOperator;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.type.TrainingSampleStream;
import net.jkernelmachines.util.DebugPrinter;

/* loaded from: input_file:net/jkernelmachines/classifier/LaSVM.class */
/**
 * Kernel SVM classifier trained with a LaSVM-style solver.
 *
 * <p>Training alternates PROCESS steps (try to insert one sample into the working set {@code S})
 * and REPROCESS steps (one optimisation step on the most violating pair, then pruning of
 * working-set points whose coefficient is zero). This runs for {@link #getE()} passes over the
 * training set, followed by a bounded number of finishing optimisation steps. The decision
 * function is {@code valueOf(x) = b + sum_{i in S} alphas[i] * k(x_i, x)}, where the stored
 * {@code alphas} are signed (see {@link #getAlphas()}).
 *
 * <p>Both batch ({@link #train(List)}) and incremental ({@link #onlineTrain(TrainingSampleStream)})
 * training are supported. Labels are expected to be {@code +1} or {@code -1}; the incremental
 * PROCESS step skips samples with any other label.
 *
 * <p>Instances are mutable and not synchronized.
 *
 * @param <T> type of the input samples
 */
public final class LaSVM<T> implements KernelSVM<T>, Serializable, OnlineClassifier<T>, Cloneable {
    private static final long serialVersionUID = -831288193185967121L;

    /** Kernel used both for the training Gram matrix and for prediction. */
    private Kernel<T> kernel;
    /** Training samples; index {@code i} of every per-sample array refers to {@code tlist.get(i)}. */
    private List<TrainingSample<T>> tlist;
    /** Snapshot of the sample objects of {@link #tlist}, used by {@link #valueOf}. */
    private T[] tarray;
    /** Working-set membership: only samples with {@code S[i]} contribute to the model. */
    private boolean[] S;
    /** Signed coefficients; the solver keeps {@code alphas[i]} within [Cmin[i], Cmax[i]]. */
    private double[] alphas;
    /** Labels, expected to be +1 or -1. */
    private int[] y;
    /** Gradient {@code y[i] - sum_j alphas[j] * K[i][j]}, maintained for working-set members. */
    private double[] g;
    /** Lower box bound {@code min(C * y[i], 0)}. */
    private double[] Cmin;
    /** Upper box bound {@code max(C * y[i], 0)}. */
    private double[] Cmax;
    /** Working-set index with the smallest gradient among points with alpha > Cmin, or -1. */
    private int imin;
    /** Working-set index with the largest gradient among points with alpha < Cmax, or -1. */
    private int imax;
    /** Gradient value at {@link #imin} as of the last {@link #minmax()} / {@link #optim} call. */
    private double gmin;
    /** Gradient value at {@link #imax} as of the last {@link #minmax()} / {@link #optim} call. */
    private double gmax;
    /** Gram matrix of {@link #tlist}; only held while training and released afterwards. */
    private double[][] kmatrix;
    /** Scratch row {@code K[j][*] - K[i][*]} for the pair (i, j) being optimised. */
    private double[] kmaxmin;
    /** Sample indices still to be processed in the current pass. */
    private LinkedList<Integer> trainQueue;
    /** Minimum gradient gap for which an optimisation step is taken. */
    private static final double tau = 1.0E-15d;
    /** Number of samples per class seeded into the working set by {@link #init()}. */
    private static final int initSampling = 5;
    /** Cap on the extra optimisation steps run after the training passes. */
    private static final int MAX_FINISHING_STEPS = 100000;
    /** True while imin/imax/gmin/gmax are up to date with respect to S, alphas and g. */
    private boolean minmaxFlag = false;
    /** Offset of the decision function. */
    private double b = 0.0d;
    /** Box constraint (soft-margin penalty). */
    private double C = 1.0d;
    /** Number of passes over the training set. */
    private int E = initSampling;
    /** Debug output; transient, so it may be null after deserialization (see debugPrinter()). */
    transient DebugPrinter debug = new DebugPrinter();

    /**
     * Creates an untrained classifier.
     *
     * @param kernel kernel used for training and prediction
     */
    public LaSVM(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Adds one sample to the training set and updates the model incrementally.
     *
     * <p>Cases:
     * <ul>
     *   <li>No existing training set: the solver is initialised from this single sample.</li>
     *   <li>The current model has no solver state (it came from an empty or single-class batch):
     *       the model is rebuilt with {@link #train(List)} on the enlarged set.</li>
     *   <li>The sample is already in the training set: it is not added again; the optimisation
     *       is simply re-run.</li>
     * </ul>
     *
     * @param trainingSample the sample to add
     * @deprecated as marked in the original API; {@link #onlineTrain(TrainingSampleStream)} and
     *     {@link #train(List)} are the non-deprecated entry points.
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    @Deprecated
    public void train(TrainingSample<T> trainingSample) {
        if (this.tlist == null || this.tlist.isEmpty()) {
            // No previous data: start the solver from scratch with this sample.
            this.tlist = new ArrayList<>();
            this.tlist.add(trainingSample);
            init();
            train();
            return;
        }
        if (this.g == null) {
            // No gradient/bound state to extend: rebuild the whole model in batch mode.
            List<TrainingSample<T>> samples = new ArrayList<>(this.tlist);
            if (!samples.contains(trainingSample)) {
                samples.add(trainingSample);
            }
            train(samples);
            return;
        }
        if (this.tlist.contains(trainingSample)) {
            // Already known: do not overwrite the state of the last sample; just re-optimise.
            // train() needs the Gram matrix, which the previous training released.
            this.kmatrix = this.kernel.getKernelMatrix(this.tlist);
            this.imin = -1;
            this.imax = -1;
            this.minmaxFlag = false;
            train();
            return;
        }
        this.tlist.add(trainingSample);
        appendLastSample(trainingSample);
        train();
    }

    /**
     * Extends all per-sample solver arrays for the sample just appended to {@link #tlist}.
     * The new sample enters the working set with a zero coefficient.
     */
    private void appendLastSample(TrainingSample<T> sample) {
        int n = this.tlist.size();
        int last = n - 1;
        this.kmatrix = this.kernel.getKernelMatrix(this.tlist);
        this.S = Arrays.copyOf(this.S, n);
        this.S[last] = true;
        this.alphas = Arrays.copyOf(this.alphas, n);
        this.y = Arrays.copyOf(this.y, n);
        this.y[last] = sample.label;
        this.g = Arrays.copyOf(this.g, n);
        double gradient = this.y[last];
        for (int j = 0; j < n; j++) {
            gradient -= this.alphas[j] * this.kmatrix[last][j];
        }
        this.g[last] = gradient;
        this.kmaxmin = Arrays.copyOf(this.kmaxmin, n);
        this.imin = -1;
        this.imax = -1;
        this.minmaxFlag = false;
        this.Cmin = Arrays.copyOf(this.Cmin, n);
        this.Cmax = Arrays.copyOf(this.Cmax, n);
        this.Cmin[last] = Math.min(this.C * this.y[last], 0.0d);
        this.Cmax[last] = Math.max(this.C * this.y[last], 0.0d);
    }

    /**
     * Trains the model from scratch on the given samples (the list is copied).
     *
     * <p>Special cases:
     * <ul>
     *   <li>An empty list clears the model, so {@link #valueOf} returns 0.</li>
     *   <li>If all samples share one label, no optimisation is run: every sample gets the
     *       coefficient {@code label * C} and {@code b} is reset to 0.</li>
     * </ul>
     *
     * @param list training samples
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        this.tlist = new ArrayList<>(list);
        if (this.tlist.isEmpty()) {
            clearModel();
            return;
        }
        int firstLabel = this.tlist.get(0).label;
        boolean singleClass = true;
        for (TrainingSample<T> sample : this.tlist) {
            if (sample.label != firstLabel) {
                singleClass = false;
                break;
            }
        }
        if (!singleClass) {
            init();
            train();
            return;
        }
        trainSingleClass(firstLabel);
    }

    /** Builds the trivial model used when every training sample has the same label. */
    private void trainSingleClass(int label) {
        // Drop solver state from any previous training so it cannot be extended by mistake.
        clearModel();
        int n = this.tlist.size();
        this.tarray = snapshotSamples();
        this.alphas = new double[n];
        Arrays.fill(this.alphas, label * this.C);
        this.S = new boolean[n];
        Arrays.fill(this.S, true);
        this.y = new int[n];
        Arrays.fill(this.y, label);
    }

    /** Resets the model and all solver state (the training list and settings are kept). */
    private void clearModel() {
        this.tarray = null;
        this.S = null;
        this.alphas = null;
        this.y = null;
        this.g = null;
        this.Cmin = null;
        this.Cmax = null;
        this.kmatrix = null;
        this.kmaxmin = null;
        this.trainQueue = null;
        this.imin = -1;
        this.imax = -1;
        this.minmaxFlag = false;
        this.b = 0.0d;
    }

    /**
     * Feeds every sample of the stream to {@link #train(TrainingSample)} until the stream
     * returns {@code null}.
     *
     * @param trainingSampleStream source of samples
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void onlineTrain(TrainingSampleStream<T> trainingSampleStream) {
        for (TrainingSample<T> sample = trainingSampleStream.nextSample();
                sample != null;
                sample = trainingSampleStream.nextSample()) {
            train(sample);
        }
    }

    /**
     * Runs the solver on the current state: {@link #E} passes of PROCESS/REPROCESS over every
     * sample index, then the finishing optimisation and model finalisation.
     * Requires {@link #kmatrix} to be computed.
     */
    private void train() {
        this.trainQueue = new LinkedList<>();
        for (int epoch = 0; epoch < this.E; epoch++) {
            for (int i = 0; i < this.tlist.size(); i++) {
                this.trainQueue.add(i);
            }
            while (!this.trainQueue.isEmpty()) {
                process(this.trainQueue.poll());
                reprocess();
            }
        }
        finishOptimization();
        reprocess();
        finalizeModel();
    }

    /**
     * Repeats optimisation steps until none applies, or until {@link #MAX_FINISHING_STEPS}
     * extra steps have been taken. In the latter case a debug warning is printed.
     */
    private void finishOptimization() {
        int remaining = MAX_FINISHING_STEPS;
        while (optim(-1, -1)) {
            if (remaining-- <= 0) {
                debugPrinter().println(2, "*** lasvm : too much reprocess.");
                break;
            }
        }
    }

    /**
     * Final cleanup shared by training and retraining:
     * <ul>
     *   <li>drops zero coefficients from the working set (and zeroes coefficients outside it);</li>
     *   <li>sets {@code b} to the midpoint of the extreme gradients;</li>
     *   <li>snapshots the samples for prediction;</li>
     *   <li>releases the Gram matrix.</li>
     * </ul>
     */
    private void finalizeModel() {
        for (int i = 0; i < this.S.length; i++) {
            if (this.alphas[i] == 0.0d) {
                this.S[i] = false;
            }
            if (!this.S[i]) {
                this.alphas[i] = 0.0d;
            }
        }
        minmax();
        this.b = (this.gmax + this.gmin) / 2.0d;
        this.tarray = snapshotSamples();
        this.kmatrix = null;
    }

    /** Copies the sample objects of {@link #tlist} into a new array. */
    private T[] snapshotSamples() {
        // Erased generic array: only ever read back as T by this class.
        @SuppressWarnings("unchecked")
        T[] samples = (T[]) new Object[this.tlist.size()];
        for (int i = 0; i < samples.length; i++) {
            samples[i] = this.tlist.get(i).sample;
        }
        return samples;
    }

    /**
     * Re-optimises the current model on the same training set with the current kernel, for
     * example after {@link #setKernel}.
     *
     * <p>The Gram matrix and the working-set gradient are recomputed, the existing coefficients
     * are re-optimised, and then {@code S}, {@code b} and the sample snapshot are refreshed as
     * in batch training. If the model was never trained this does nothing. If it has no solver
     * state (empty or single-class batch), it is retrained with {@link #train(List)}.
     */
    public void retrain() {
        if (this.tlist == null) {
            return;
        }
        if (this.g == null) {
            train(new ArrayList<>(this.tlist));
            return;
        }
        this.kmatrix = this.kernel.getKernelMatrix(this.tlist);
        final double[] gradient = new double[this.g.length];
        // Each row i of the gradient is written only by the call covering row i.
        // NOTE(review): assumes ThreadedMatrixOperator hands out disjoint row ranges — confirm.
        new ThreadedMatrixOperator() { // from class: net.jkernelmachines.classifier.LaSVM.1
            @Override // net.jkernelmachines.threading.ThreadedMatrixOperator
            public void doLines(double[][] matrix, int from, int to) {
                for (int i = from; i < to; i++) {
                    if (!LaSVM.this.S[i]) {
                        continue;
                    }
                    double value = LaSVM.this.y[i];
                    for (int j = 0; j < matrix[i].length; j++) {
                        value -= LaSVM.this.alphas[j] * LaSVM.this.kmatrix[i][j];
                    }
                    gradient[i] = value;
                }
            }
        }.getMatrix(this.kmatrix);
        this.g = gradient;
        // The cached extremes refer to the old gradient; force minmax() to recompute them.
        this.minmaxFlag = false;
        finishOptimization();
        reprocess();
        finalizeModel();
    }

    /**
     * Allocates fresh solver state for {@link #tlist}. The working set is seeded with up to
     * {@link #initSampling} samples of each class (+1 / -1), all with zero coefficient.
     */
    private void init() {
        int n = this.tlist.size();
        this.S = new boolean[n];
        this.alphas = new double[n];
        this.y = new int[n];
        this.g = new double[n];
        this.kmaxmin = new double[n];
        this.imin = -1;
        this.imax = -1;
        this.minmaxFlag = false;
        this.Cmin = new double[n];
        this.Cmax = new double[n];
        for (int i = 0; i < n; i++) {
            this.y[i] = this.tlist.get(i).label;
            this.Cmin[i] = Math.min(this.C * this.y[i], 0.0d);
            this.Cmax[i] = Math.max(this.C * this.y[i], 0.0d);
        }
        this.kmatrix = this.kernel.getKernelMatrix(this.tlist);
        int positives = 0;
        int negatives = 0;
        // Stop scanning as soon as both classes have their quota of seeds.
        for (int i = 0; i < n && (positives < initSampling || negatives < initSampling); i++) {
            if (this.y[i] == 1 && positives < initSampling) {
                this.S[i] = true;
                this.g[i] = 1.0d; // gradient equals the label while all alphas are zero
                positives++;
            } else if (this.y[i] == -1 && negatives < initSampling) {
                this.S[i] = true;
                this.g[i] = -1.0d;
                negatives++;
            }
        }
    }

    /**
     * Recomputes imin/imax/gmin/gmax over the working set, unless they are already up to date.
     * <ul>
     *   <li>imin: smallest gradient among members whose alpha can still decrease.</li>
     *   <li>imax: largest gradient among members whose alpha can still increase.</li>
     * </ul>
     */
    private final void minmax() {
        if (this.minmaxFlag) {
            return;
        }
        this.imin = -1;
        this.imax = -1;
        this.gmin = Double.POSITIVE_INFINITY;
        this.gmax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < this.S.length; i++) {
            if (!this.S[i]) {
                continue;
            }
            double alpha = this.alphas[i];
            double grad = this.g[i];
            if (alpha > this.Cmin[i] && grad < this.gmin) {
                this.gmin = grad;
                this.imin = i;
            }
            if (alpha < this.Cmax[i] && grad > this.gmax) {
                this.gmax = grad;
                this.imax = i;
            }
        }
        this.minmaxFlag = true;
    }

    /**
     * Performs one two-coefficient optimisation step.
     *
     * <p>The step moves {@code alphas[j]} up and {@code alphas[i]} down by the same amount. The
     * amount is the unconstrained optimum, clipped so both stay in their box. The gradients of
     * the working set are then updated.
     *
     * @param i index of the coefficient to decrease, or negative to use {@link #imin}
     * @param j index of the coefficient to increase, or negative to use {@link #imax}
     * @return {@code true} if a step was taken; {@code false} if no pair is available, the
     *     gradient gap is below {@link #tau}, or the box leaves no room to move
     */
    private boolean optim(int i, int j) {
        minmax();
        int lo = i < 0 ? this.imin : i;
        int hi = j < 0 ? this.imax : j;
        if (lo < 0 || hi < 0) {
            return false;
        }
        this.gmin = this.g[lo];
        this.gmax = this.g[hi];
        double gap = this.g[hi] - this.g[lo];
        if (gap < tau) {
            return false;
        }
        double room = Math.min(this.alphas[lo] - this.Cmin[lo], this.Cmax[hi] - this.alphas[hi]);
        if (room == 0.0d) {
            return false;
        }
        double curvature =
                (this.kmatrix[lo][lo] + this.kmatrix[hi][hi]) - (2.0d * this.kmatrix[lo][hi]);
        double step = Math.min(gap / curvature, room);
        this.alphas[hi] += step;
        this.alphas[lo] -= step;
        for (int k = 0; k < this.S.length; k++) {
            this.kmaxmin[k] = this.kmatrix[hi][k] - this.kmatrix[lo][k];
        }
        for (int k = 0; k < this.S.length; k++) {
            if (this.S[k]) {
                this.g[k] -= step * this.kmaxmin[k];
            }
        }
        this.minmaxFlag = false;
        return true;
    }

    /**
     * PROCESS step: tries to insert sample {@code i} into the working set and, if inserted, runs
     * one optimisation step pairing it with the current extreme.
     *
     * @param i sample index
     * @return {@code true} if the sample was inserted
     */
    private boolean process(int i) {
        if (this.S[i]) {
            return false;
        }
        if (this.y[i] != 1 && this.y[i] != -1) {
            return false;
        }
        this.alphas[i] = 0.0d;
        double grad = this.y[i];
        for (int j = 0; j < this.S.length; j++) {
            if (this.S[j]) {
                grad -= this.alphas[j] * this.kmatrix[i][j];
            }
        }
        minmax();
        if (this.gmin < this.gmax) {
            // Skip samples whose gradient could not produce a violating pair.
            if (this.Cmin[i] >= 0.0d && grad < this.gmin) {
                return false;
            }
            if (this.Cmax[i] <= 0.0d && grad > this.gmax) {
                return false;
            }
        }
        this.S[i] = true;
        this.g[i] = grad;
        this.minmaxFlag = false;
        if (this.Cmin[i] >= 0.0d) {
            optim(-1, i);
        } else {
            optim(i, -1);
        }
        return true;
    }

    /**
     * REPROCESS step: one optimisation step on the most violating pair, then removal from the
     * working set of zero-coefficient points whose gradient lies beyond the current extreme.
     *
     * @return whether the optimisation step was taken
     */
    private boolean reprocess() {
        boolean stepped = optim(-1, -1);
        minmax();
        for (int i = 0; i < this.S.length; i++) {
            if (!this.S[i] || this.alphas[i] != 0.0d) {
                continue;
            }
            if (this.y[i] == -1) {
                if (this.g[i] >= this.gmax) {
                    this.S[i] = false;
                }
            } else if (this.g[i] <= this.gmin) {
                this.S[i] = false;
            }
        }
        return stepped;
    }

    /**
     * Evaluates the decision function {@code b + sum_{i in S} alphas[i] * k(x_i, t)}.
     *
     * @param t sample to evaluate
     * @return the decision value, or 0 if the model has not been trained
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public final double valueOf(T t) {
        if (this.S == null) {
            return 0.0d;
        }
        double sum = this.b;
        for (int i = 0; i < this.S.length; i++) {
            if (this.S[i]) {
                sum += this.alphas[i] * this.kernel.valueOf(this.tarray[i], t);
            }
        }
        return sum;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    /** Sets the box constraint; it only takes effect for subsequently trained or added samples. */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    /** Returns the number of passes over the training set. */
    public int getE() {
        return this.E;
    }

    /** Sets the number of passes over the training set. */
    public void setE(int i) {
        this.E = i;
    }

    /** Returns the minimum gradient gap for which an optimisation step is taken. */
    public double getTau() {
        return tau;
    }

    /**
     * Returns the coefficients multiplied by their labels ({@code alphas[i] * y[i]}), as a new
     * array; empty if the model has not been trained.
     */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        if (this.alphas == null) {
            return new double[0];
        }
        double[] result = new double[this.alphas.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = this.alphas[i] * this.y[i];
        }
        return result;
    }

    /** Returns the offset of the decision function. */
    public double getB() {
        return this.b;
    }

    /** Overrides the offset of the decision function. */
    public void setB(double d) {
        this.b = d;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.kernel;
    }

    /** Replaces the kernel; call {@link #retrain()} to re-optimise the model for it. */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Returns a copy of this classifier.
     *
     * <p>The training list and all per-sample arrays are copied, so training either instance
     * does not affect the other. The kernel, the Gram matrix (read-only) and the sample objects
     * themselves are shared.
     *
     * @return an independent copy
     * @throws CloneNotSupportedException if cloning fails
     */
    @Override // net.jkernelmachines.classifier.Classifier
    /* renamed from: copy */
    public LaSVM<T> copy2() throws CloneNotSupportedException {
        @SuppressWarnings("unchecked")
        LaSVM<T> copy = (LaSVM<T>) super.clone();
        copy.tlist = this.tlist == null ? null : new ArrayList<>(this.tlist);
        copy.tarray = this.tarray == null ? null : this.tarray.clone();
        copy.S = this.S == null ? null : this.S.clone();
        copy.alphas = this.alphas == null ? null : this.alphas.clone();
        copy.y = this.y == null ? null : this.y.clone();
        copy.g = this.g == null ? null : this.g.clone();
        copy.Cmin = this.Cmin == null ? null : this.Cmin.clone();
        copy.Cmax = this.Cmax == null ? null : this.Cmax.clone();
        copy.kmaxmin = this.kmaxmin == null ? null : this.kmaxmin.clone();
        // Scratch queue; rebuilt at the start of every training run.
        copy.trainQueue = null;
        return copy;
    }

    /** Returns the debug printer, recreating it if deserialization left the transient field null. */
    private DebugPrinter debugPrinter() {
        if (this.debug == null) {
            this.debug = new DebugPrinter();
        }
        return this.debug;
    }
}
