package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

import net.jkernelmachines.kernel.typed.DoubleLinear;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.type.TrainingSampleStream;
import net.jkernelmachines.util.DebugPrinter;

/* loaded from: input_file:net/jkernelmachines/classifier/DoublePegasosSVM.class */
/**
 * Linear SVM trained with the Pegasos stochastic sub-gradient solver on dense {@code double[]}
 * samples.
 *
 * <p>Each optimisation step does the following:
 * <ol>
 *   <li>Draws a random mini-batch of at most {@code k} distinct training samples.</li>
 *   <li>Keeps those whose margin {@code label * (kernel.valueOf(w, x) - b)} is at most 1.</li>
 *   <li>Takes a sub-gradient step of size {@code 1 / (lambda * (t + t0))}, where {@code t} is
 *       the number of steps taken since the last reset.</li>
 *   <li>Scales {@code w} (and {@code b}) so that the norm of {@code w} is at most
 *       {@code 1 / sqrt(lambda)}.</li>
 * </ol>
 *
 * <p>The decision function is {@code kernel.valueOf(w, x) - b}, with a {@link DoubleLinear}
 * kernel. No synchronization is performed, so instances must not be trained concurrently.
 */
public class DoublePegasosSVM implements OnlineClassifier<double[]>, Serializable, Cloneable {
    private static final long serialVersionUID = 5289136605543751554L;

    /** Source of randomness for mini-batch sampling; static, hence never serialized. */
    private static final Random RANDOM = new Random();

    /** Samples the mini-batches are drawn from; always a private list, never the caller's. */
    private List<TrainingSample<double[]>> tList;
    /** Weight vector; allocated by the train methods (or supplied through {@link #setW}). */
    private double[] w;
    private DoubleLinear kernel = new DoubleLinear();
    /** Bias, subtracted from {@code kernel.valueOf(w, x)}. */
    private double b = 0.0d;
    /** Number of optimisation steps performed by {@link #train(List)}. */
    int T = 100000;
    /** Mini-batch size. */
    int k = 10;
    /** Regularisation parameter. */
    double lambda = 0.001d;
    /** Offset added to the step counter in the step size {@code 1 / (lambda * (t + t0))}. */
    double t0 = 100.0d;
    /** Whether the bias is learned; when false it is held at 0. */
    boolean bias = true;
    /** Soft-margin constant, only used when {@link #hasC} is set. */
    double C = 1.0d;
    /** When true, {@link #train(List)} sets {@code lambda = 1 / (C * n)}. */
    boolean hasC = false;
    DebugPrinter debug = new DebugPrinter();
    /** Optimisation steps taken since the last (re)start of training; drives the step decay. */
    private int i = 0;

    /**
     * Adds one sample to the training set and performs a single optimisation step.
     *
     * <p>The weight vector is lazily allocated from the first sample's dimension.
     *
     * @param trainingSample the sample to add
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void train(TrainingSample<double[]> trainingSample) {
        if (this.w == null) {
            this.w = new double[trainingSample.sample.length];
            this.b = 0.0d;
        }
        if (this.tList == null) {
            this.tList = new ArrayList<TrainingSample<double[]>>();
        }
        this.tList.add(trainingSample);
        __train();
    }

    /**
     * Trains from scratch on every sample of {@code trainingSampleStream}, one step per sample.
     *
     * <p>The weights, bias, step counter and stored sample list are reset first. An empty
     * stream leaves the classifier untouched.
     *
     * @param trainingSampleStream source of samples; {@code nextSample()} returning
     *     {@code null} ends the stream
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void onlineTrain(TrainingSampleStream<double[]> trainingSampleStream) {
        TrainingSample<double[]> first = trainingSampleStream.nextSample();
        if (first == null) {
            return;
        }
        this.w = new double[first.sample.length];
        this.b = 0.0d;
        this.i = 0;
        // Start from an empty sample list: old samples must not leak into the fresh model.
        this.tList = new ArrayList<TrainingSample<double[]>>();
        for (TrainingSample<double[]> sample = first; sample != null;
                sample = trainingSampleStream.nextSample()) {
            train(sample);
        }
    }

    /**
     * Trains from scratch on {@code list} with {@link #T} optimisation steps.
     *
     * <p>Side effects:
     * <ul>
     *   <li>If {@code k} exceeds the number of samples, {@code k} is lowered to that number.</li>
     *   <li>If {@link #setC} was called, {@code lambda} is set to {@code 1 / (C * n)}.</li>
     * </ul>
     * The list is copied, so later online updates never modify it.
     *
     * @param list the training samples
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("training list must not be empty");
        }
        if (this.k > list.size()) {
            this.k = list.size();
        }
        // Private copy: train(TrainingSample) appends to tList and must not touch the caller's list.
        this.tList = new ArrayList<TrainingSample<double[]>>(list);
        this.w = new double[this.tList.get(0).sample.length];
        this.b = 0.0d;
        this.i = 0;
        if (this.hasC) {
            this.lambda = 1.0d / (this.C * this.tList.size());
        }
        this.debug.println(1, "begin training");
        long start = System.currentTimeMillis();
        for (int step = 0; step < this.T; step++) {
            __train();
        }
        this.debug.println(2, "");
        this.debug.println(1, "done in " + (System.currentTimeMillis() - start) + " ms");
        this.debug.println(3, "w : " + Arrays.toString(this.w) + " b : " + this.b);
    }

    /**
     * Performs one Pegasos step on a random mini-batch of {@link #tList} and advances the
     * step counter. Requires {@code tList} to be non-empty and {@code w} to be allocated.
     */
    private void __train() {
        final int dim = this.w.length;
        final int n = this.tList.size();
        // Average over the samples actually drawn (k may exceed n in online mode). Use at
        // least 1 so that a non-positive k cannot turn the bias update into 0/0 = NaN.
        final int batchSize = Math.max(1, Math.min(this.k, n));

        List<TrainingSample<double[]>> violators = marginViolators(sampleIndices(n, batchSize));

        double eta = 1.0d / (this.lambda * (this.i + this.t0));
        double shrink = 1.0d - eta * this.lambda;
        double scale = eta / batchSize;

        double[] next = new double[dim];
        for (int d = 0; d < dim; d++) {
            next[d] = shrink * this.w[d];
        }
        double labelSum = 0.0d;
        for (TrainingSample<double[]> sample : violators) {
            double[] x = sample.sample;
            for (int d = 0; d < dim; d++) {
                if (x[d] != 0.0d) { // zero features contribute nothing; skip them
                    next[d] += scale * sample.label * x[d];
                }
            }
            labelSum += sample.label;
        }

        // Scale w so its norm is at most 1/sqrt(lambda). A zero norm yields +Infinity here,
        // which the min() clamps to 1.
        double radius = 1.0d / Math.sqrt(this.lambda);
        double projection = Math.min(1.0d, radius / Math.sqrt(this.kernel.valueOf(next, next)));
        for (int d = 0; d < dim; d++) {
            next[d] *= projection;
        }
        this.w = next;
        this.b = this.bias ? projection * (shrink * this.b - scale * labelSum) : 0.0d;

        this.debug.println(4, "w : " + Arrays.toString(this.w) + " b : " + this.b);
        if (this.T > 20 && this.i % (this.T / 20) == 0) {
            this.debug.print(2, ".");
        }
        this.i++; // advance t so the step size decays as in Pegasos
    }

    /**
     * Returns the samples at {@code indices} whose margin {@code label * (w.x - b)}, computed
     * with the current weights, is at most 1.
     */
    private List<TrainingSample<double[]>> marginViolators(Set<Integer> indices) {
        List<TrainingSample<double[]>> violators =
                new ArrayList<TrainingSample<double[]>>(indices.size());
        for (int index : indices) {
            TrainingSample<double[]> sample = this.tList.get(index);
            if ((this.kernel.valueOf(this.w, sample.sample) - this.b) * sample.label <= 1.0d) {
                violators.add(sample);
            }
        }
        return violators;
    }

    /**
     * Draws {@code count} distinct indices uniformly from {@code [0, n)} with Floyd's
     * algorithm. This takes O(count) expected time instead of shuffling all {@code n} indices.
     * Requires {@code 0 <= count <= n}.
     */
    private static Set<Integer> sampleIndices(int n, int count) {
        Set<Integer> chosen = new HashSet<Integer>(2 * count);
        for (int j = n - count; j < n; j++) {
            int candidate = RANDOM.nextInt(j + 1);
            if (!chosen.add(candidate)) {
                // candidate was already taken; j itself cannot have been chosen yet
                chosen.add(j);
            }
        }
        return chosen;
    }

    /**
     * Evaluates the decision function {@code kernel.valueOf(w, x) - b}.
     *
     * <p>{@code w} is only set by the train methods or {@link #setW}, so call one of them first.
     *
     * @param dArr the sample to score
     * @return the signed score; positive values indicate the positive class
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(double[] dArr) {
        return this.kernel.valueOf(this.w, dArr) - this.b;
    }

    /** @return the number of optimisation steps used by {@link #train(List)} */
    public int getT() {
        return this.T;
    }

    /** @param i the number of optimisation steps used by {@link #train(List)} */
    public void setT(int i) {
        this.T = i;
    }

    /** @return the mini-batch size */
    public int getK() {
        return this.k;
    }

    /** @param i the mini-batch size; {@link #train(List)} lowers it to the training-set size */
    public void setK(int i) {
        this.k = i;
    }

    /** @return the regularisation parameter */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d the regularisation parameter (overridden by {@link #train(List)} after setC) */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /** @return the internal weight vector (not a copy) */
    public double[] getW() {
        return this.w;
    }

    /** @param dArr the weight vector to use (stored without copying) */
    public void setW(double[] dArr) {
        this.w = dArr;
    }

    /** @return the bias */
    public double getB() {
        return this.b;
    }

    /** @param d the bias */
    public void setB(double d) {
        this.b = d;
    }

    /** @return whether a bias term is learned */
    public boolean isBias() {
        return this.bias;
    }

    /** @param z whether a bias term is learned; when false the bias is held at 0 */
    public void setBias(boolean z) {
        this.bias = z;
    }

    /** @return the step-size offset {@code t0} */
    public double getT0() {
        return this.t0;
    }

    /** @param d the step-size offset {@code t0} */
    public void setT0(double d) {
        this.t0 = d;
    }

    /**
     * Sets the soft-margin constant. From then on, {@link #train(List)} derives
     * {@code lambda = 1 / (C * n)}.
     *
     * @param d the soft-margin constant
     */
    public void setC(double d) {
        this.hasC = true;
        this.C = d;
    }

    /**
     * Returns a copy of this classifier.
     *
     * <p>The copy's weight vector and sample list are independent of this instance, so
     * training it leaves this one unchanged. The samples themselves, the kernel and the debug
     * printer are shared.
     *
     * @return the copy
     * @throws CloneNotSupportedException declared by the interface; not thrown, since this
     *     class implements {@link Cloneable}
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public DoublePegasosSVM copy() throws CloneNotSupportedException {
        DoublePegasosSVM copy = (DoublePegasosSVM) super.clone();
        if (this.w != null) {
            copy.w = this.w.clone();
        }
        if (this.tList != null) {
            copy.tList = new ArrayList<TrainingSample<double[]>>(this.tList);
        }
        return copy;
    }

    /** @return the soft-margin constant if {@link #setC} was called, otherwise 0 */
    public double getC() {
        if (this.hasC) {
            return this.C;
        }
        return 0.0d;
    }
}
