package net.jkernelmachines.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.kernel.typed.DoubleLinear;
import net.jkernelmachines.type.TrainingSample;

/* loaded from: input_file:net/jkernelmachines/classifier/DoubleSGDQN.class */
/**
 * Linear SVM trained with SGD-QN: stochastic gradient descent whose step is
 * rescaled per coordinate by a diagonal matrix {@code Bc}, re-estimated from the
 * change in gradient between two consecutive weight vectors.
 *
 * <p>The decision function is {@code f(x) = w . x} (no bias term). When
 * {@link #setNormalize(boolean) normalize} is enabled, features are
 * standardised with the training-set mean and inverse standard deviation, and
 * the same transform is applied in {@link #valueOf(double[])}.
 *
 * <p>This class holds mutable training state and performs no synchronization.
 */
public class DoubleSGDQN implements Classifier<double[]>, Cloneable {
    /** Hinge loss {@code max(0, 1 - z)}. */
    public static final int HINGELOSS = 1;
    /** Smoothed hinge: {@code 0.5 - z} for z &lt; 0, {@code 0.5 (1 - z)^2} on [0, 1), 0 above. */
    public static final int SMOOTHHINGELOSS = 2;
    /** Squared hinge {@code 0.5 * max(0, 1 - z)^2}. */
    public static final int SQUAREDHINGELOSS = 3;
    /** Logistic loss {@code log(1 + exp(-z))}; codes &gt;= this update on every example. */
    public static final int LOGLOSS = 10;
    /** Logistic loss with margin {@code log(1 + exp(1 - z))}. */
    public static final int LOGLOSSMARGIN = 11;

    /** Examples remaining before the next weight-decay / rescaling step. */
    private int count;
    /** Iteration counter; the learning rate at each step is {@code 1 / t}. */
    private long t;
    /** Per-feature training mean (set only when normalizing). */
    double[] mean;
    /** Per-feature inverse standard deviation, 0 for constant features (set only when normalizing). */
    double[] var;
    /** When true, progress information is printed to standard output. */
    public static boolean VERBOSE = false;
    /** Training examples; normalised copies when normalizing. */
    List<TrainingSample<double[]>> tlist;
    private int loss = HINGELOSS;
    private double[] w = null;
    /** Diagonal rescaling applied to every gradient step. */
    private double[] Bc = null;
    /** Number of examples between two weight-decay / rescaling steps. */
    private int skip = 0;
    /** Initial value of {@code t}, chosen by {@link #determineT0(int, int)}. */
    private long t0 = 0;
    private double lambda = 1.0E-4d;
    private int epochs = 5;
    /** NOTE(review): not read anywhere in this class. */
    double eps = 0.01d;
    boolean hasC = false;
    double C = 1.0d;
    private boolean normalize = false;
    private final Kernel<double[]> dot = new DoubleLinear();

    /**
     * Trains the model on the given examples.
     *
     * <p>If {@link #setC(double)} was called, lambda becomes {@code 1 / (C * n)}.
     * The learning-rate offset t0 is selected by trying powers of ten on the
     * examples {@code 0 .. n/10}, then {@link #getEpochs()} in-order passes are
     * made over all examples. The caller's samples are never modified.
     *
     * @param list training examples; labels are presumably +1/-1 (TODO confirm)
     * @throws IllegalArgumentException if {@code list} is empty
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("cannot train on an empty list");
        }
        this.tlist = new ArrayList<TrainingSample<double[]>>(list);
        final int size = this.tlist.size();
        if (this.hasC) {
            this.lambda = 1.0d / (this.C * size);
        }
        if (this.normalize) {
            computeNormalization();
            // Work on normalised copies so the caller's arrays stay untouched
            // (mutating them in place would double-normalise on retraining).
            List<TrainingSample<double[]>> normalized = new ArrayList<TrainingSample<double[]>>(size);
            for (TrainingSample<double[]> ts : this.tlist) {
                normalized.add(new TrainingSample<double[]>(standardize(ts.sample), ts.label));
            }
            this.tlist = normalized;
        }

        final long start = System.currentTimeMillis();
        initSVM(); // allocates w so its dimension can be reported
        if (VERBOSE) {
            System.out.println("dimension of w : " + this.w.length);
        }
        this.t0 = determineT0(0, size / 10);
        if (VERBOSE) {
            System.out.println("t0 set to " + this.t0 + " (" + (System.currentTimeMillis() - start) + " ms.)");
        }
        initSVM();
        calibrate(0, size - 1);
        long lap = System.currentTimeMillis();
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            train(0, size - 1);
            if (VERBOSE) {
                long now = System.currentTimeMillis();
                System.out.println("epoch " + epoch + " time : " + (now - lap) + " ms.");
                lap = now;
            }
        }
        if (VERBOSE) {
            System.out.println("done in " + (System.currentTimeMillis() - start) + " ms.");
        }
    }

    /**
     * Computes {@code mean} and {@code var} (inverse standard deviation, 0 for
     * constant features) from {@code tlist}, without modifying the samples.
     */
    private void computeNormalization() {
        final int size = this.tlist.size();
        this.mean = new double[this.tlist.get(0).sample.length];
        this.var = new double[this.mean.length];
        for (TrainingSample<double[]> ts : this.tlist) {
            for (int i = 0; i < ts.sample.length; i++) {
                this.mean[i] += ts.sample[i];
            }
        }
        for (int i = 0; i < this.mean.length; i++) {
            this.mean[i] /= size;
        }
        if (VERBOSE) {
            System.out.println("||mean|| = " + this.dot.valueOf(this.mean, this.mean));
        }
        for (TrainingSample<double[]> ts : this.tlist) {
            for (int i = 0; i < ts.sample.length; i++) {
                double centered = ts.sample[i] - this.mean[i];
                this.var[i] += centered * centered;
            }
        }
        for (int i = 0; i < this.var.length; i++) {
            double variance = this.var[i] / size;
            this.var[i] = variance == 0.0d ? 0.0d : Math.sqrt(1.0d / variance);
        }
        if (VERBOSE) {
            System.out.println("||var|| = " + this.dot.valueOf(this.var, this.var));
        }
    }

    /** Returns a new array {@code (x[i] - mean[i]) * var[i]}; {@code x} is not modified. */
    private double[] standardize(double[] x) {
        double[] z = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            z[i] = (x[i] - this.mean[i]) * this.var[i];
        }
        return z;
    }

    /**
     * Returns {@code w . x}, standardising {@code x} first when normalization is
     * enabled. The argument array is not modified.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(double[] dArr) {
        return this.dot.valueOf(this.w, this.normalize ? standardize(dArr) : dArr);
    }

    /** Resets {@code w} to zero, {@code Bc} to {@code 1 / lambda} and {@code t} to {@code t0}. */
    private void initSVM() {
        this.w = new double[this.tlist.get(0).sample.length];
        this.Bc = new double[this.w.length];
        Arrays.fill(this.Bc, 1.0d / this.lambda);
        this.t = this.t0;
    }

    /**
     * Runs one SGD-QN pass over the examples {@code imin..imax} (inclusive).
     *
     * <p>Every {@code skip} examples the weights receive a rescaled L2 decay
     * {@code w -= skip * lambda * eta * Bc * w}. The next example is then used to
     * re-estimate {@code Bc} from the gradient difference.
     */
    private void train(int imin, int imax) {
        this.count = this.skip;
        boolean updateB = false;
        final double bMin = 1.0d / (100.0d * this.lambda);
        final double bMax = 100.0d / this.lambda;
        for (int i = imin; i <= imax; i++) {
            TrainingSample<double[]> ts = this.tlist.get(i);
            double[] x = ts.sample;
            double y = ts.label;
            double z = y * this.dot.valueOf(this.w, x);
            double eta = 1.0d / this.t;
            // Hinge-type losses only update on margin violations; logistic ones always.
            boolean active = this.loss >= LOGLOSS || z < 1.0d;
            if (updateB) {
                if (active) {
                    double[] wPrev = this.w.clone();
                    double dl = dloss(z);
                    for (int j = 0; j < this.w.length; j++) {
                        this.w[j] += x[j] * this.Bc[j] * eta * dl * y;
                    }
                    double diffLoss = dloss(y * this.dot.valueOf(this.w, x)) - dl;
                    if (diffLoss != 0.0d) {
                        double[] ratio = computeRatio(x, this.lambda, wPrev, this.w, y * diffLoss);
                        // Floating-point weights: with long arithmetic these
                        // fractions truncated to 0/1 and destroyed the average.
                        double tt = this.t;
                        double sk = this.skip;
                        if (this.t > this.skip) {
                            combineAndClip(this.Bc, (tt - sk) / (tt + sk), ratio, (2.0d * sk) / (tt + sk), bMin, bMax);
                        } else {
                            combineAndClip(this.Bc, tt / (tt + sk), ratio, sk / (tt + sk), bMin, bMax);
                        }
                    }
                }
                updateB = false;
            } else {
                if (--this.count <= 0) {
                    // Gradient of the 0.5 * lambda * ||w||^2 regularizer is lambda * w.
                    for (int j = 0; j < this.w.length; j++) {
                        this.w[j] -= this.skip * this.lambda * eta * this.Bc[j] * this.w[j];
                    }
                    this.count = this.skip;
                    updateB = true;
                }
                if (active) {
                    double step = eta * dloss(z) * y;
                    for (int j = 0; j < this.w.length; j++) {
                        this.w[j] += x[j] * step * this.Bc[j];
                    }
                }
            }
            this.t++;
        }
    }

    /**
     * Returns the regularised objective on examples {@code imin..imax}:
     * the average loss plus {@code 0.5 * lambda * ||w||^2}.
     */
    private double test(int imin, int imax) {
        double cost = 0.0d;
        for (int i = imin; i <= imax; i++) {
            TrainingSample<double[]> ts = this.tlist.get(i);
            double z = ts.label * this.dot.valueOf(this.w, ts.sample);
            if (this.loss >= LOGLOSS || z < 1.0d) {
                cost += loss(z);
            }
        }
        return (cost / ((imax - imin) + 1)) + (0.5d * this.lambda * this.dot.valueOf(this.w, this.w));
    }

    /**
     * Sets {@code skip} from the density of the examples {@code imin..imax}:
     * {@code skip = 8 * count * dim / nonZeros}.
     */
    private void calibrate(int imin, int imax) {
        double examples = 0.0d;
        double nonZeros = 0.0d;
        for (int i = imin; i <= imax; i++) {
            for (double v : this.tlist.get(i).sample) {
                if (v != 0.0d) {
                    nonZeros += 1.0d;
                }
            }
            // NOTE(review): each example is counted twice, doubling skip. This
            // appears to mirror the reference C++ calibrate; kept so the
            // heuristic is unchanged.
            examples += 2.0d;
        }
        this.skip = (int) (((8.0d * examples) * this.w.length) / nonZeros);
    }

    /**
     * Per-coordinate secant estimate {@code dw / dg} of the inverse curvature.
     * Coordinates whose weight did not move get {@code 1 / lambda}.
     *
     * @param x     the example
     * @param lambda regularization strength
     * @param wPrev weights before the step
     * @param wNext weights after the step
     * @param yDiffLoss label times the change in {@code dloss}
     */
    private double[] computeRatio(double[] x, double lambda, double[] wPrev, double[] wNext, double yDiffLoss) {
        double[] ratio = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double dw = wPrev[i] - wNext[i];
            ratio[i] = dw != 0.0d ? dw / ((lambda * dw) + (yDiffLoss * x[i])) : 1.0d / lambda;
        }
        return ratio;
    }

    /** In place: {@code target = clip(target * c1 + other * c2, lo, hi)}. */
    private void combineAndClip(double[] target, double c1, double[] other, double c2, double lo, double hi) {
        for (int i = 0; i < target.length; i++) {
            double mixed = (target[i] * c1) + (other[i] * c2);
            target[i] = Math.min(Math.max(mixed, lo), hi);
        }
    }

    /**
     * Tries t0 = 1, 10, ..., 10^10 on the examples {@code imin..imax} and returns
     * the value giving the lowest non-NaN objective (1 if none qualifies).
     */
    private long determineT0(int imin, int imax) {
        long best = 1;
        double bestCost = Double.MAX_VALUE;
        long candidate = 1;
        for (int k = 0; k <= 10; k++, candidate *= 10) {
            // t0 must be set before initSVM(), which copies it into t.
            this.t0 = candidate;
            initSVM();
            calibrate(imin, imax);
            train(imin, imax);
            double cost = test(imin, imax);
            if (cost < bestCost && !Double.isNaN(cost)) {
                best = candidate;
                bestCost = cost;
            }
        }
        return best;
    }

    /** Value of the selected loss at margin {@code z}; 0 for unknown codes. */
    private double loss(double z) {
        switch (this.loss) {
            case HINGELOSS:
                return z < 1.0d ? 1.0d - z : 0.0d;
            case SMOOTHHINGELOSS:
                if (z < 0.0d) {
                    return 0.5d - z;
                }
                return z < 1.0d ? 0.5d * (1.0d - z) * (1.0d - z) : 0.0d;
            case SQUAREDHINGELOSS:
                return z < 1.0d ? 0.5d * (1.0d - z) * (1.0d - z) : 0.0d;
            case LOGLOSS:
                // Branches keep exp() from overflowing.
                return z >= 0.0d ? Math.log(1.0d + Math.exp(-z)) : (-z) + Math.log(1.0d + Math.exp(z));
            case LOGLOSSMARGIN:
                return z >= 1.0d ? Math.log(1.0d + Math.exp(1.0d - z)) : (1.0d - z) + Math.log(1.0d + Math.exp(z - 1.0d));
            default:
                return 0.0d;
        }
    }

    /**
     * Negated derivative {@code -d loss / dz} of the selected loss at {@code z}.
     * Hinge and unknown codes use the hinge derivative.
     */
    private double dloss(double z) {
        switch (this.loss) {
            case SMOOTHHINGELOSS:
                if (z < 0.0d) {
                    return 1.0d;
                }
                return z < 1.0d ? 1.0d - z : 0.0d;
            case SQUAREDHINGELOSS:
                return z < 1.0d ? 1.0d - z : 0.0d;
            case LOGLOSS: {
                if (z < 0.0d) {
                    return 1.0d / (Math.exp(z) + 1.0d);
                }
                double e = Math.exp(-z);
                return e / (e + 1.0d);
            }
            case LOGLOSSMARGIN: {
                if (z < 1.0d) {
                    return 1.0d / (Math.exp(z - 1.0d) + 1.0d);
                }
                double e = Math.exp(1.0d - z);
                return e / (e + 1.0d);
            }
            default:
                return z < 1.0d ? 1.0d : 0.0d;
        }
    }

    /** @return the loss code (one of the {@code *LOSS} constants) */
    public int getLoss() {
        return this.loss;
    }

    /** @param i loss code, one of the {@code *LOSS} constants */
    public void setLoss(int i) {
        this.loss = i;
    }

    /** @return the weight vector (the internal array, not a copy) */
    public double[] getW() {
        return this.w;
    }

    /** @param dArr new weight vector; stored without copying */
    public void setW(double[] dArr) {
        this.w = dArr;
    }

    /** @return the regularization strength lambda */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d regularization strength; overridden at training time if {@link #setC} was called */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /** @return number of passes over the training data */
    public int getEpochs() {
        return this.epochs;
    }

    /** @param i number of passes over the training data */
    public void setEpochs(int i) {
        this.epochs = i;
    }

    /** @return whether features are standardised */
    public boolean isNormalize() {
        return this.normalize;
    }

    /** @param z whether to standardise features with training-set statistics */
    public void setNormalize(boolean z) {
        this.normalize = z;
    }

    /** @return the C parameter */
    public double getC() {
        return this.C;
    }

    /** Sets C; at training time lambda is then derived as {@code 1 / (C * n)}. */
    public void setC(double d) {
        this.C = d;
        this.hasC = true;
    }

    /**
     * Returns an independent copy of this classifier. The model arrays are
     * duplicated so that changes to one instance do not leak into the other.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public DoubleSGDQN copy() throws CloneNotSupportedException {
        DoubleSGDQN c = (DoubleSGDQN) super.clone();
        c.w = this.w == null ? null : this.w.clone();
        c.Bc = this.Bc == null ? null : this.Bc.clone();
        c.mean = this.mean == null ? null : this.mean.clone();
        c.var = this.var == null ? null : this.var.clone();
        return c;
    }

    /** Decompiler-era alias kept for compatibility; delegates to {@link #copy()}. */
    public Classifier<double[]> copy2() throws CloneNotSupportedException {
        return copy();
    }
}
