package edu.iu.dsc.tws.examples.ml.svm.sgd.pegasos;

import edu.iu.dsc.tws.examples.ml.svm.exceptions.MatrixMultiplicationException;
import edu.iu.dsc.tws.examples.ml.svm.exceptions.NullDataSetException;
import edu.iu.dsc.tws.examples.ml.svm.math.Initializer;
import edu.iu.dsc.tws.examples.ml.svm.math.Matrix;
import edu.iu.dsc.tws.examples.ml.svm.sgd.SgdSvm;
import java.io.Serializable;
import java.util.logging.Logger;

/* loaded from: input_file:edu/iu/dsc/tws/examples/ml/svm/sgd/pegasos/PegasosSgdSvm.class */
/**
 * Pegasos-style stochastic sub-gradient trainer for a linear SVM.
 *
 * <p>All training entry points apply the same per-sample step (see {@link #pegasosStep}).
 * <ul>
 *   <li>When {@code y * (x . w) < 1}, the weights become {@code w - alpha * (w - y * x)}.</li>
 *   <li>Otherwise they become {@code w - alpha * w}.</li>
 * </ul>
 */
public class PegasosSgdSvm extends SgdSvm implements Serializable {
    private static final long serialVersionUID = -8279454451787246995L;
    private static final Logger LOG = Logger.getLogger(PegasosSgdSvm.class.getName());

    /** Most recent shrink-only step vector ({@code alpha * w}). */
    private double[] wa;

    /** Feature count; set by the constructors that receive it, otherwise 0. */
    private int features;

    /** Most recent margin-violation step vector ({@code alpha * (w - y * x)}). */
    private double[] xyia;

    /**
     * Forwards all arguments to the matching {@link SgdSvm} constructor; the feature count
     * stays 0.
     */
    public PegasosSgdSvm(double[] dArr, double d, double d2, int i) {
        super(dArr, d, d2, i);
        this.features = 0;
    }

    /**
     * Forwards {@code dArr2}, {@code d}, {@code d2} and {@code i} to the same {@link SgdSvm}
     * constructor as the four-argument overload; the feature count stays 0.
     *
     * <p>NOTE(review): {@code dArr} is not used at all. Confirm this is intended.
     */
    public PegasosSgdSvm(double[] dArr, double[] dArr2, double d, double d2, int i) {
        super(dArr2, d, d2, i);
        this.features = 0;
    }

    /**
     * Forwards to the superclass, then replaces the weight vector with a freshly initialised one.
     *
     * @param i2 number of features; also sizes the vector from {@code Initializer.initialWeights}
     */
    public PegasosSgdSvm(double[] dArr, double d, int i, int i2) {
        super(dArr, d, i, i2);
        this.features = i2;
        this.w = Initializer.initialWeights(this.features);
    }

    /**
     * Forwards to the superclass, then sets the weight vector from {@code dArr}.
     *
     * @param dArr initial weights; when {@code null}, {@code Initializer.initialWeights(i2)} is used
     * @param i2   number of features
     */
    public PegasosSgdSvm(double[] dArr, double[][] dArr2, double[] dArr3, double d, int i, int i2) {
        super(dArr, dArr2, dArr3, d, i);
        this.features = i2;
        this.w = (dArr == null) ? Initializer.initialWeights(this.features) : dArr;
    }

    /**
     * Only validates the data source; performs no training.
     *
     * @deprecated performs no weight update; the iterative/online methods of this class do the
     *     actual training.
     * @throws NullDataSetException if the data source was flagged invalid ({@code isInvalid})
     */
    @Override
    @Deprecated
    public void sgd() throws NullDataSetException, MatrixMultiplicationException {
        if (this.isInvalid) {
            throw new NullDataSetException("Invalid data source with no features or no data");
        }
    }

    /**
     * Runs {@code iterations} full passes over the data, one Pegasos step per sample, starting
     * from {@code weights}. The final vector is stored via {@link #setW}.
     *
     * @param weights starting weight vector
     * @param samples one feature vector per sample
     * @param labels  label for each sample, indexed like {@code samples}
     */
    @Override
    public void iterativeSgd(double[] weights, double[][] samples, double[] labels)
            throws NullDataSetException, MatrixMultiplicationException {
        double[] current = weights;
        for (int pass = 0; pass < this.iterations; pass++) {
            for (int i = 0; i < samples.length; i++) {
                current = pegasosStep(current, samples[i], labels[i]);
            }
        }
        setW(current);
    }

    /**
     * Runs a single pass over the data, one Pegasos step per sample, starting from the current
     * model ({@code this.w}). The final vector is stored via {@link #setW}.
     *
     * <p>NOTE(review): {@code weights} is not read; the pass always starts from {@code this.w}.
     * Confirm callers do not expect the argument to be used.
     *
     * @param weights unused
     * @param samples one feature vector per sample
     * @param labels  label for each sample, indexed like {@code samples}
     */
    @Override
    public void iterativeTaskSgd(double[] weights, double[][] samples, double[] labels)
            throws NullDataSetException, MatrixMultiplicationException {
        double[] current = this.w;
        for (int i = 0; i < samples.length; i++) {
            current = pegasosStep(current, samples[i], labels[i]);
        }
        setW(current);
    }

    /**
     * Applies one Pegasos step for a single sample, starting from {@code weights}, and stores the
     * result via {@link #setW}.
     *
     * @param weights current weight vector
     * @param sample  feature vector of the sample
     * @param label   label of the sample
     */
    @Override
    public void onlineSGD(double[] weights, double[] sample, double label)
            throws NullDataSetException, MatrixMultiplicationException {
        setW(pegasosStep(weights, sample, label));
    }

    /** No-op in this implementation. */
    @Override
    public <T> void onlineDynamicSGD(T[] tArr, T[] tArr2, T t) {
    }

    /** Returns the current weight vector (the internal array, not a copy). */
    @Override
    public double[] getW() {
        return this.w;
    }

    /** Replaces the current weight vector with {@code dArr} (stored without copying). */
    @Override
    public void setW(double[] dArr) {
        this.w = dArr;
    }

    /**
     * Computes one Pegasos sub-gradient update for a single sample.
     *
     * <p>The intermediate step vector is recorded in {@code xyia} (margin violated) or {@code wa}
     * (margin satisfied).
     *
     * @param weights current weight vector; not modified here, a new vector is returned
     * @param sample  feature vector of the sample
     * @param label   label of the sample
     * @return the updated weight vector
     */
    private double[] pegasosStep(double[] weights, double[] sample, double label)
            throws NullDataSetException, MatrixMultiplicationException {
        if (label * Matrix.dot(sample, weights) < 1.0d) {
            // Margin violated: step = alpha * (w - y * x).
            this.xyia = Matrix.scalarMultiply(
                    Matrix.subtract(weights, Matrix.scalarMultiply(sample, label)), this.alpha);
            return Matrix.subtract(weights, this.xyia);
        }
        // Margin satisfied: only shrink the weights, step = alpha * w.
        this.wa = Matrix.scalarMultiply(weights, this.alpha);
        return Matrix.subtract(weights, this.wa);
    }
}
