package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import net.jkernelmachines.kernel.typed.DoubleLinear;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.type.TrainingSampleStream;

/* loaded from: input_file:net/jkernelmachines/classifier/DoubleSGD.class */
/**
 * Linear classifier on {@code double[]} samples trained by stochastic gradient descent.
 *
 * <p>Each step shrinks the weights by {@code (1 - eta * lambda)} (L2 regularization) and, when the
 * selected loss has a non-zero gradient at the margin {@code z = y * (w.x + b)}, moves them by
 * {@code eta * dloss(z) * y * x}, with {@code eta = 1 / (lambda * t)}. The regularization shrink is
 * applied lazily through the scalar {@code wscale}: the effective weight vector is
 * {@code w * wscale}.
 *
 * <p>Not thread-safe.
 */
public class DoubleSGD implements Classifier<double[]>, OnlineClassifier<double[]>, Serializable, Cloneable {
    private static final long serialVersionUID = 3245177176254451010L;

    /** Hinge loss, {@code max(0, 1 - z)}. */
    public static final int HINGELOSS = 1;
    /** Smoothed hinge: {@code 1/2 - z} for {@code z < 0}, {@code (1 - z)^2 / 2} for {@code 0 <= z < 1}, else 0. */
    public static final int SMOOTHHINGELOSS = 2;
    /** Squared hinge, {@code max(0, 1 - z)^2 / 2}. */
    public static final int SQUAREDHINGELOSS = 3;
    /** Logistic loss, {@code log(1 + exp(-z))}. */
    public static final int LOGLOSS = 10;
    /** Logistic loss with a unit margin, {@code log(1 + exp(1 - z))}. */
    public static final int LOGLOSSMARGIN = 11;

    /** Bias term; forced to 0 after each epoch when {@link #hasBias} is false. */
    double bias;
    /** Iteration counter driving the step size {@code eta = 1 / (lambda * t)}; always >= 1 once trained. */
    private long t;
    /** Lazy scale factor: the effective weights are {@code w[i] * wscale}. */
    private double wscale;
    /** Selected loss, one of the {@code *LOSS} constants; unknown values behave like {@link #HINGELOSS}. */
    private int loss = HINGELOSS;
    /** Unscaled weight vector; {@code null} until the first training call. */
    private double[] w = null;
    boolean hasBias = true;
    /** L2 regularization strength; must be positive. */
    private double lambda = 1.0E-4d;
    private int epochs = 5;
    private boolean shuffle = false;
    /** Scores {@code w} against a sample (presumably a dot product — TODO confirm in DoubleLinear). */
    DoubleLinear linear = new DoubleLinear();

    /**
     * Trains from scratch on {@code list} for {@link #getEpochs()} passes.
     *
     * <p>The weight dimension is taken from the first sample. An empty list is ignored and leaves the
     * current model untouched.
     *
     * @param list the training samples; not modified, even when shuffling is enabled
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<double[]>> list) {
        if (list.isEmpty()) {
            return;
        }
        initialize(list.get(0).sample.length);
        for (int i = 0; i < this.epochs; i++) {
            trainOnce(list);
        }
    }

    /**
     * Performs a single online SGD update. The model is initialized from this sample's dimension if it
     * has not been trained yet.
     *
     * @param trainingSample the sample to learn from
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void train(TrainingSample<double[]> trainingSample) {
        if (this.w == null) {
            initialize(trainingSample.sample.length);
        }
        step(trainingSample);
    }

    /**
     * Consumes {@code trainingSampleStream} until it returns {@code null}, applying one online update
     * per sample.
     *
     * @param trainingSampleStream source of samples; {@code nextSample() == null} marks the end
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void onlineTrain(TrainingSampleStream<double[]> trainingSampleStream) {
        TrainingSample<double[]> next;
        while ((next = trainingSampleStream.nextSample()) != null) {
            train(next);
        }
    }

    /**
     * Runs one epoch (one pass over {@code list}) on the current model. It does nothing if the model
     * has never been initialized.
     *
     * <p>When shuffling is enabled, a private copy of {@code list} is shuffled, so the caller's list is
     * neither reordered nor required to be modifiable.
     *
     * @param list the training samples for this epoch
     */
    public void trainOnce(List<TrainingSample<double[]>> list) {
        if (this.w == null) {
            return;
        }
        List<TrainingSample<double[]>> order = list;
        if (this.shuffle) {
            order = new ArrayList<>(list);
            Collections.shuffle(order);
        }
        for (TrainingSample<double[]> sample : order) {
            step(sample);
        }
        if (!this.hasBias) {
            this.bias = 0.0d;
        }
    }

    /**
     * Allocates a zero weight vector of {@code dimension} entries and resets the step-size schedule.
     *
     * <p>The initial rate {@code eta0} is derived from a typical weight magnitude
     * {@code lambda^(-1/4)}, and {@code t} is chosen so that {@code 1 / (lambda * t)} starts at
     * {@code eta0}.
     */
    private void initialize(int dimension) {
        this.w = new double[dimension];
        this.wscale = 1.0d;
        this.bias = 0.0d;
        double typicalW = Math.sqrt(1.0d / Math.sqrt(this.lambda));
        double eta0 = typicalW / Math.max(1.0d, dloss(-typicalW));
        // t is a long: when eta0 * lambda > 1 (e.g. hinge loss with lambda > 1) the cast truncates to 0,
        // which would make the first step size infinite and the weights NaN. Start at 1 at least.
        this.t = Math.max(1L, (long) (1.0d / (eta0 * this.lambda)));
    }

    /** Applies one SGD update for {@code sample} with the current step size, then advances {@code t}. */
    private void step(TrainingSample<double[]> sample) {
        double eta = 1.0d / (this.lambda * this.t);
        // Regularization shrinks every weight by (1 - eta * lambda); fold it into wscale lazily.
        this.wscale *= 1.0d - (eta * this.lambda);
        if (this.wscale < 1.0E-9d) {
            // Renormalize before wscale underflows.
            for (int i = 0; i < this.w.length; i++) {
                this.w[i] *= this.wscale;
            }
            this.wscale = 1.0d;
        }
        double[] x = sample.sample;
        double y = sample.label;
        double z = y * ((this.linear.valueOf(this.w, x) * this.wscale) + this.bias);
        // Hinge-type losses have zero gradient once z >= 1, so they may skip the update. The logistic
        // losses have a non-zero gradient everywhere and must always update.
        if (z < 1.0d || this.loss >= LOGLOSS) {
            double etd = eta * dloss(z);
            for (int i = 0; i < this.w.length; i++) {
                this.w[i] += (x[i] * etd * y) / this.wscale;
            }
            if (this.hasBias) {
                // The bias is updated at 1% of the weight rate.
                this.bias += etd * y * 0.01d;
            }
        }
        this.t++;
    }

    /**
     * Returns the decision value {@code w.x * wscale + bias} for {@code dArr}.
     *
     * <p>NOTE(review): on an untrained model {@code w} is {@code null}; the outcome then depends on
     * {@code DoubleLinear.valueOf} — confirm.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(double[] dArr) {
        return (this.linear.valueOf(this.w, dArr) * this.wscale) + this.bias;
    }

    /**
     * Returns the negated derivative {@code -dL/dz} of the selected loss at margin {@code z}.
     * Unknown loss codes fall back to the hinge loss.
     */
    private double dloss(double z) {
        switch (this.loss) {
            case SMOOTHHINGELOSS:
                if (z < 0.0d) {
                    return 1.0d;
                }
                return z < 1.0d ? 1.0d - z : 0.0d;
            case SQUAREDHINGELOSS:
                return z < 1.0d ? 1.0d - z : 0.0d;
            case LOGLOSS: {
                // Two algebraically equal forms, chosen by sign so exp() never overflows.
                if (z < 0.0d) {
                    return 1.0d / (Math.exp(z) + 1.0d);
                }
                double e = Math.exp(-z);
                return e / (e + 1.0d);
            }
            case LOGLOSSMARGIN: {
                if (z < 1.0d) {
                    return 1.0d / (Math.exp(z - 1.0d) + 1.0d);
                }
                double e = Math.exp(1.0d - z);
                return e / (e + 1.0d);
            }
            case HINGELOSS:
            default:
                return z < 1.0d ? 1.0d : 0.0d;
        }
    }

    /**
     * Returns the model's weight vector, or {@code null} if the model is untrained.
     *
     * <p>Internally the weights are stored as {@code w * wscale}. This method first folds the pending
     * scale into the array, so the returned (live, not copied) array holds the effective weights. The
     * model's predictions are unchanged by this.
     */
    public double[] getW() {
        if (this.w != null && this.wscale != 1.0d) {
            for (int i = 0; i < this.w.length; i++) {
                this.w[i] *= this.wscale;
            }
            this.wscale = 1.0d;
        }
        return this.w;
    }

    /** Returns the selected loss code (one of the {@code *LOSS} constants). */
    public int getLoss() {
        return this.loss;
    }

    /** Selects the loss by code; unknown codes behave like {@link #HINGELOSS}. */
    public void setLoss(int i) {
        this.loss = i;
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @param d the new lambda
     * @throws IllegalArgumentException if {@code d} is not strictly positive (including NaN), since the
     *     step size {@code 1 / (lambda * t)} would be infinite or NaN
     */
    public void setLambda(double d) {
        if (!(d > 0.0d)) {
            throw new IllegalArgumentException("lambda must be positive: " + d);
        }
        this.lambda = d;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    /** Enables or disables the bias term; when disabled the bias is reset to 0 after each epoch. */
    public void setHasBias(boolean z) {
        this.hasBias = z;
    }

    public int getEpochs() {
        return this.epochs;
    }

    /** Sets the number of passes performed by {@link #train(List)}. */
    public void setEpochs(int i) {
        this.epochs = i;
    }

    public boolean isShuffle() {
        return this.shuffle;
    }

    /** Enables shuffling of (a copy of) the sample order at the start of each epoch. */
    public void setShuffle(boolean z) {
        this.shuffle = z;
    }

    /**
     * Returns an independent copy of this classifier. The weight array is deep-copied, so training the
     * copy does not affect this instance. The {@code linear} scorer is shared (assumed stateless —
     * confirm).
     */
    @Override // net.jkernelmachines.classifier.Classifier
    /* renamed from: copy */
    public Classifier<double[]> copy2() throws CloneNotSupportedException {
        DoubleSGD copy = (DoubleSGD) super.clone();
        if (this.w != null) {
            copy.w = this.w.clone();
        }
        return copy;
    }

    /** Returns the L2 regularization strength. */
    public double getLambda() {
        return this.lambda;
    }
}
