package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.DebugPrinter;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/classifier/DoubleSAG.class */
/**
 * Linear classifier over {@code double[]} samples, trained incrementally from
 * a running sum of per-sample gradient contributions (the class name suggests
 * the Stochastic Average Gradient scheme, SAG).
 *
 * <p>Training keeps one scalar per sample ({@code yi}), a running gradient sum
 * for the weights ({@code d}) and for the bias ({@code db}). Each step replaces
 * the visited sample's contribution in those sums and then moves {@code w} and
 * {@code b} along the averaged sums with step size {@code alpha} and shrinkage
 * factor {@code 1 - alpha * lambda}.
 *
 * <p>NOTE(review): {@link #valueOf(double[])} returns only the dot product with
 * {@code w}; the bias {@code b} maintained by {@link #train(List)} is exposed
 * through {@link #getB()} but is not added to the decision value. Confirm
 * whether that is intended before relying on {@code b}.
 */
public class DoubleSAG implements Classifier<double[]>, Serializable {
    private static final long serialVersionUID = -3497156096402090039L;

    /** Loss selector: hinge loss. Also the behaviour for any unrecognised selector. */
    public static final int HINGELOSS = 1;
    /** Loss selector: smooth hinge loss. */
    public static final int SMOOTHHINGELOSS = 2;
    /** Loss selector: squared hinge loss. */
    public static final int SQUAREDHINGELOSS = 3;
    /** Loss selector: logistic loss. */
    public static final int LOGLOSS = 10;
    /** Loss selector: logistic loss shifted by a margin of one. */
    public static final int LOGLOSSMARGIN = 11;

    /** Weight vector; allocated by {@link #train(List)} or injected with {@link #setW(double[])}. */
    double[] w;
    /** Step size, derived in {@link #train(List)} from the largest sample n2 value. */
    private double alpha;
    /** Last loss-derivative value stored for each training sample. */
    private double[] yi;
    /** Running sum of {@code yi[i] * label_i} (bias part of the gradient sum). */
    private double db;
    /** Running gradient sum for the weights. */
    private double[] d;
    /** Length of the first training sample. */
    private int dim;
    /** Number of training samples. */
    private int n;
    /** Selected loss, one of the {@code *LOSS} constants. */
    private int loss = HINGELOSS;
    /** When true, epochs visit samples in list order; otherwise in shuffled order. */
    private boolean cyclic = true;
    /** Bias term maintained during training. */
    double b = 0.0d;
    /** Regularisation coefficient used in the shrinkage factor. */
    double lambda = 1.0E-4d;
    /** Number of epochs run after the initial pass. */
    long E = 4;
    /** Debug output sink; transient, so it is null on a freshly deserialized instance. */
    transient DebugPrinter debug = new DebugPrinter();

    /**
     * Trains the classifier from scratch on the given samples, discarding any
     * previously learned weights and bias.
     *
     * <p>Runs one initial pass in list order (averaging over the number of
     * samples seen so far), followed by {@code E} epochs that average over all
     * samples.
     *
     * @param list training samples; must be non-empty, and the first sample's
     *             length is taken as the dimension
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws NullPointerException if {@code list} is null
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("cannot train on an empty list of samples");
        }
        // The field initializer does not run on deserialization, so restore the printer here.
        if (this.debug == null) {
            this.debug = new DebugPrinter();
        }

        this.dim = list.get(0).sample.length;
        this.n = list.size();
        this.yi = new double[this.n];
        this.w = new double[this.dim];
        this.b = 0.0d;

        // Step size is 1 / (4 * largest n2 value over all samples).
        // NOTE(review): if every sample yields n2 == 0 this is an infinite step size.
        double largestNorm = 0.0d;
        for (TrainingSample<double[]> sample : list) {
            double norm = VectorOperations.n2(sample.sample);
            if (norm > largestNorm) {
                largestNorm = norm;
            }
        }
        this.alpha = 1.0d / (4.0d * largestNorm);
        this.d = new double[this.dim];
        this.db = 0.0d;

        // Initial pass: only i + 1 samples have contributed so far, so average over that count.
        this.debug.println(3, "First epoch");
        for (int i = 0; i < this.n; i++) {
            TrainingSample<double[]> sample = list.get(i);
            step(i, sample.sample, sample.label, i + 1);
        }

        // The visiting order is only needed (and reshuffled every epoch) in non-cyclic mode.
        final boolean shuffled = !this.cyclic;
        List<Integer> order = null;
        if (shuffled) {
            order = new ArrayList<Integer>(this.n);
            for (int i = 0; i < this.n; i++) {
                order.add(Integer.valueOf(i));
            }
        }

        // A long counter keeps the loop finite even when E exceeds Integer.MAX_VALUE.
        for (long epoch = 0; epoch < this.E; epoch++) {
            this.debug.println(3, "epoch " + epoch);
            if (shuffled) {
                Collections.shuffle(order);
            }
            for (int position = 0; position < this.n; position++) {
                int index = shuffled ? order.get(position).intValue() : position;
                TrainingSample<double[]> sample = list.get(index);
                update(index, sample.sample, sample.label);
            }
        }
        this.debug.println(3, "w: " + Arrays.toString(this.w));
        this.debug.println(3, "b: " + this.b);
    }

    /**
     * Performs one training step for sample {@code i}, averaging the gradient
     * sums over all {@code n} training samples.
     *
     * @param i    index of the sample in the training list
     * @param dArr feature vector of the sample
     * @param i2   label of the sample
     */
    private final void update(int i, double[] dArr, int i2) {
        step(i, dArr, i2, this.n);
    }

    /**
     * Shared step used by the initial pass and by {@link #update}: swaps the
     * sample's old contribution in the gradient sums for a fresh one, then
     * updates {@code w} and {@code b}.
     *
     * @param index   index of the sample in the training list
     * @param x       feature vector of the sample
     * @param label   label of the sample
     * @param divisor number of samples the gradient sums are averaged over
     */
    private void step(int index, double[] x, int label, int divisor) {
        // Remove the contribution previously stored for this sample.
        VectorOperations.addi(this.d, this.d, (-this.yi[index]) * label, x);
        this.db -= this.yi[index] * label;

        // Re-evaluate the loss derivative at the current model and add it back.
        this.yi[index] = dloss(label * valueOf(x));
        VectorOperations.addi(this.d, this.d, this.yi[index] * label, x);
        this.db += this.yi[index] * label;

        // Shrink the model by (1 - alpha * lambda), then move along the averaged sums.
        double shrink = 1.0d - (this.alpha * this.lambda);
        VectorOperations.muli(this.w, this.w, shrink);
        VectorOperations.addi(this.w, this.w, this.alpha / divisor, this.d);
        this.b = (shrink * this.b) + ((this.alpha * this.db) / divisor);
    }

    /**
     * Returns the decision value for a sample: the dot product of the weight
     * vector with {@code dArr}. The bias {@code b} is not added.
     *
     * @param dArr sample to evaluate
     * @return {@code VectorOperations.dot(w, dArr)}
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(double[] dArr) {
        return VectorOperations.dot(this.w, dArr);
    }

    /**
     * Returns an independent copy of this classifier. The weight vector and
     * the internal training arrays are duplicated, so modifying one instance's
     * arrays does not affect the other.
     *
     * @return the copy
     * @throws CloneNotSupportedException if {@code clone()} is not supported
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public Classifier<double[]> copy() throws CloneNotSupportedException {
        DoubleSAG duplicate = (DoubleSAG) clone();
        // Object.clone() is shallow; give the copy its own arrays.
        duplicate.w = duplicateArray(this.w);
        duplicate.yi = duplicateArray(this.yi);
        duplicate.d = duplicateArray(this.d);
        return duplicate;
    }

    /** Returns a copy of {@code source}, or null when {@code source} is null. */
    private static double[] duplicateArray(double[] source) {
        return source == null ? null : source.clone();
    }

    /**
     * Returns the value stored per sample for the selected loss, evaluated at
     * {@code margin} (label times decision value).
     *
     * @param margin label multiplied by the current decision value
     * @return the loss-dependent value; the hinge rule is used for
     *         {@link #HINGELOSS} and for any unrecognised selector
     */
    private double dloss(double margin) {
        switch (this.loss) {
            case SMOOTHHINGELOSS:
                if (margin < 0.0d) {
                    return 1.0d;
                }
                return margin < 1.0d ? 1.0d - margin : 0.0d;
            case SQUAREDHINGELOSS:
                return margin < 1.0d ? 1.0d - margin : 0.0d;
            case LOGLOSS: {
                // Two algebraically equal forms; each branch only exponentiates a non-positive value.
                if (margin < 0.0d) {
                    return 1.0d / (Math.exp(margin) + 1.0d);
                }
                double e = Math.exp(-margin);
                return e / (e + 1.0d);
            }
            case LOGLOSSMARGIN: {
                // Same as LOGLOSS with the margin shifted by one.
                if (margin < 1.0d) {
                    return 1.0d / (Math.exp(margin - 1.0d) + 1.0d);
                }
                double e = Math.exp(1.0d - margin);
                return e / (e + 1.0d);
            }
            case HINGELOSS:
            default:
                return margin < 1.0d ? 1.0d : 0.0d;
        }
    }

    /** @return the selected loss, as one of the {@code *LOSS} constants */
    public int getLoss() {
        return this.loss;
    }

    /** @param i loss selector; values other than the {@code *LOSS} constants behave as {@link #HINGELOSS} */
    public void setLoss(int i) {
        this.loss = i;
    }

    /** @return the weight vector itself (not a copy); null before training unless set explicitly */
    public double[] getW() {
        return this.w;
    }

    /** @param dArr weight vector to use; stored by reference */
    public void setW(double[] dArr) {
        this.w = dArr;
    }

    /** @return the bias term */
    public double getB() {
        return this.b;
    }

    /** @param d new bias term */
    public void setB(double d) {
        this.b = d;
    }

    /** @return the regularisation coefficient */
    public double getLambda() {
        return this.lambda;
    }

    /** @param d new regularisation coefficient */
    public void setLambda(double d) {
        this.lambda = d;
    }

    /** @return the number of epochs run after the initial pass */
    public long getE() {
        return this.E;
    }

    /** @param j number of epochs to run after the initial pass */
    public void setE(long j) {
        this.E = j;
    }

    /** @return true when epochs visit samples in list order */
    public boolean isCyclic() {
        return this.cyclic;
    }

    /** @param z true for list order, false to shuffle the visiting order every epoch */
    public void setCyclic(boolean z) {
        this.cyclic = z;
    }
}
