package net.jkernelmachines.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.type.TrainingSampleStream;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/classifier/SDCA.class */
/**
 * Kernel SVM without bias term, trained by coordinate-wise updates of the hinge-loss dual (the
 * class name suggests Stochastic Dual Coordinate Ascent).
 *
 * <p>The decision function is {@code f(x) = sum_i alphas[i] * K(samples[i], x)}. Each
 * {@code alphas[i]} stores the label-signed dual variable {@code y_i * a_i}, and every update
 * clips {@code a_i} to {@code [0, C]}.
 *
 * <p>Training modes:
 * <ul>
 *   <li>{@link #train(List)}: {@code E} epochs over a freshly shuffled batch, using a precomputed
 *       kernel matrix when caching is enabled and the matrix fits in the remaining heap;</li>
 *   <li>{@link #train(TrainingSample)} and {@link #onlineTrain(TrainingSampleStream)}: one update
 *       per incoming sample, appending samples not yet in the training set.</li>
 * </ul>
 *
 * <p>No synchronization is performed, so training one instance from several threads is unsafe.
 *
 * @param <T> type of the samples handled by the kernel
 */
public class SDCA<T> implements KernelSVM<T>, OnlineClassifier<T> {
    /** Kernel used to compare samples. */
    Kernel<T> kernel;
    /** Training inputs, index-aligned with {@link #alphas}. */
    T[] samples;
    /** Labels of the last batch training set; only filled by {@link #train(List)}. */
    int[] labels;
    /** Label-signed dual coefficients, one per entry of {@link #samples}. */
    double[] alphas;
    /** Training samples seen so far, index-aligned with {@link #alphas}. */
    List<TrainingSample<T>> train;
    /** Size of the last batch training set. */
    private int n;
    /** Kernel matrix of the last batch training set, or {@code null} when it was not cached. */
    private double[][] km;
    /** Upper bound of the (unsigned) dual variables. */
    double C = 1.0d;
    /** Number of passes over the data performed by {@link #train(List)}. */
    int E = 5;
    /** Whether {@link #train(List)} may precompute the full kernel matrix. */
    private boolean cacheKernel = true;

    /**
     * Creates an untrained classifier.
     *
     * @param kernel kernel used to compare samples
     */
    public SDCA(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Performs one dual update for a single sample, appending it to the model first if it is not
     * already part of the training set (membership is decided by {@code List.indexOf}).
     *
     * @param trainingSample sample to learn from
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void train(TrainingSample<T> trainingSample) {
        ensureOnlineState();
        addOrUpdate(trainingSample);
    }

    /**
     * Batch training: runs {@code E} epochs of dual updates, visiting all samples in a new random
     * order each epoch. Any previous model is discarded.
     *
     * @param list training samples
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        this.n = list.size();
        this.train = new ArrayList<>(list);
        // Never reuse a matrix cached by a previous call: it may belong to a different data set.
        this.km = null;
        // long arithmetic: n * n * 8 overflows int from n = 16384 on.
        if (this.cacheKernel && fitsInHeap((long) this.n * this.n * 8L)) {
            this.km = this.kernel.getKernelMatrix(this.train);
        }

        @SuppressWarnings("unchecked") // erased generic array, only ever read back as T
        T[] inputs = (T[]) new Object[this.n];
        this.samples = inputs;
        this.labels = new int[this.n];
        for (int i = 0; i < this.n; i++) {
            TrainingSample<T> trainingSample = list.get(i);
            this.samples[i] = trainingSample.sample;
            this.labels[i] = trainingSample.label;
        }
        this.alphas = new double[this.n];

        List<Integer> order = new ArrayList<>(this.n);
        for (int i = 0; i < this.n; i++) {
            order.add(i);
        }
        // Use the cached update only when the matrix was actually built; caching may be enabled
        // but skipped for lack of memory.
        boolean useCache = this.km != null;
        for (int epoch = 0; epoch < this.E; epoch++) {
            Collections.shuffle(order);
            for (int index : order) {
                if (useCache) {
                    update(index);
                } else {
                    updateNoCache(index);
                }
            }
        }
    }

    /**
     * Consumes samples from {@code trainingSampleStream} until it returns {@code null}, performing
     * one dual update per sample, and prints a running sample count to standard output.
     *
     * <p>Side effect: disables kernel caching for this instance, including later batch training.
     *
     * @param trainingSampleStream source of training samples
     */
    @Override // net.jkernelmachines.classifier.OnlineClassifier
    public void onlineTrain(TrainingSampleStream<T> trainingSampleStream) {
        ensureOnlineState();
        this.cacheKernel = false;
        int count = 0;
        TrainingSample<T> next;
        while ((next = trainingSampleStream.nextSample()) != null) {
            count++;
            addOrUpdate(next);
            // Progress counter, rewritten in place on the console line.
            System.out.print("\r" + count);
        }
        System.out.println("\r");
    }

    /** Creates empty online-training state if no model has been built yet. */
    private void ensureOnlineState() {
        if (this.train == null || this.alphas == null) {
            this.train = new ArrayList<>();
            this.alphas = new double[0];
            @SuppressWarnings("unchecked") // erased generic array, only ever read back as T
            T[] empty = (T[]) new Object[0];
            this.samples = empty;
        }
    }

    /**
     * Runs one uncached dual update for {@code sample}, first appending it with a zero coefficient
     * if it is not yet in the training set.
     */
    private void addOrUpdate(TrainingSample<T> sample) {
        int index = this.train.indexOf(sample);
        if (index < 0) {
            this.train.add(sample);
            this.alphas = Arrays.copyOf(this.alphas, this.alphas.length + 1);
            this.samples = Arrays.copyOf(this.samples, this.samples.length + 1);
            this.samples[this.samples.length - 1] = sample.sample;
            index = this.train.size() - 1;
        }
        updateNoCache(index);
    }

    /** Returns true if {@code bytes} more bytes fit in the heap the JVM may still use. */
    private static boolean fitsInHeap(long bytes) {
        Runtime runtime = Runtime.getRuntime();
        // freeMemory() alone only covers the currently committed heap, which can still grow.
        long used = runtime.totalMemory() - runtime.freeMemory();
        return bytes < runtime.maxMemory() - used;
    }

    /**
     * Dual update of coordinate {@code i} using the cached kernel matrix:
     * {@code a_i <- clip((1 - y_i f(x_i)) / K_ii + a_i, 0, C)}, stored signed as
     * {@code alphas[i] = y_i * a_i}. Requires {@link #km} and {@link #labels} from batch training.
     */
    private final void update(int i) {
        double y = this.labels[i];
        double margin = y * VectorOperations.dot(this.alphas, this.km[i]);
        double step = ((1.0d - margin) / this.km[i][i]) + (y * this.alphas[i]);
        this.alphas[i] = y * Math.max(0.0d, Math.min(this.C, step));
    }

    /**
     * Same update as {@link #update(int)}, but evaluates the kernel on demand instead of reading a
     * cached matrix; works for samples added online.
     */
    private final void updateNoCache(int i) {
        TrainingSample<T> trainingSample = this.train.get(i);
        double y = trainingSample.label;
        double margin = y * valueOf(trainingSample.sample);
        double selfKernel = this.kernel.valueOf(trainingSample.sample, trainingSample.sample);
        double step = ((1.0d - margin) / selfKernel) + (y * this.alphas[i]);
        this.alphas[i] = y * Math.max(0.0d, Math.min(this.C, step));
    }

    /**
     * Evaluates the decision function {@code sum_i alphas[i] * K(samples[i], t)}; the sign gives the
     * predicted class.
     *
     * @param t sample to evaluate
     * @return decision value
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(T t) {
        double d = 0.0d;
        for (int i = 0; i < this.alphas.length; i++) {
            d += this.alphas[i] * this.kernel.valueOf(this.samples[i], t);
        }
        return d;
    }

    /**
     * Returns a copy of this classifier that owns its own coefficient, sample, label and
     * training-set containers, so training the copy never mutates this instance. The cached kernel
     * matrix is shared because it is only read, and replaced rather than modified on retraining.
     *
     * @return independent copy of this classifier
     * @throws CloneNotSupportedException if {@code Object.clone()} refuses to clone this instance
     */
    @Override // net.jkernelmachines.classifier.Classifier
    /* renamed from: copy */
    public Classifier<T> copy2() throws CloneNotSupportedException {
        @SuppressWarnings("unchecked") // clone() of an SDCA<T> is an SDCA<T>
        SDCA<T> copy = (SDCA<T>) clone();
        if (this.alphas != null) {
            copy.alphas = this.alphas.clone();
        }
        if (this.samples != null) {
            copy.samples = this.samples.clone();
        }
        if (this.labels != null) {
            copy.labels = this.labels.clone();
        }
        if (this.train != null) {
            copy.train = new ArrayList<>(this.train);
        }
        return copy;
    }

    /** @return number of epochs performed by batch training */
    public int getE() {
        return this.E;
    }

    /** @param i number of epochs performed by batch training */
    public void setE(int i) {
        this.E = i;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Returns the unsigned dual variables {@code a_i = alphas[i] * y_i}, one per training sample.
     *
     * @return fresh array of dual variables
     */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        double[] dArr = new double[this.alphas.length];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = this.alphas[i] * this.train.get(i).label;
        }
        return dArr;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    /**
     * Computes {@code (0.5 * sum_ij alphas[i] alphas[j] K_ij + C * sum_i max(0, 1 - y_i f(x_i))) / n}
     * over the last batch training set. Uses the cached kernel matrix when available, otherwise
     * evaluates the kernel. Returns NaN when no batch training has been performed ({@code n == 0}).
     *
     * @return primal objective divided by the batch size
     */
    public double getObjective() {
        double d = 0.0d;
        for (int i = 0; i < this.n; i++) {
            for (int j = 0; j < this.n; j++) {
                d += 0.5d * this.alphas[i] * this.alphas[j] * kernelAt(i, j);
            }
        }
        for (int i = 0; i < this.n; i++) {
            d += this.C * Math.max(0.0d, 1.0d - (this.labels[i] * valueOf(this.samples[i])));
        }
        return d / this.n;
    }

    /**
     * Computes {@code (sum_i y_i alphas[i] - 0.5 * sum_ij alphas[i] alphas[j] K_ij) / n} over the
     * last batch training set. Uses the cached kernel matrix when available, otherwise evaluates
     * the kernel. Returns NaN when no batch training has been performed ({@code n == 0}).
     *
     * @return dual objective divided by the batch size
     */
    public double getDualObjective() {
        double d = 0.0d;
        for (int i = 0; i < this.n; i++) {
            d += this.labels[i] * this.alphas[i];
        }
        for (int i = 0; i < this.n; i++) {
            for (int j = 0; j < this.n; j++) {
                d -= ((0.5d * this.alphas[i]) * this.alphas[j]) * kernelAt(i, j);
            }
        }
        return d / this.n;
    }

    /** Kernel value between training samples {@code i} and {@code j}, cached when possible. */
    private double kernelAt(int i, int j) {
        return this.km != null ? this.km[i][j] : this.kernel.valueOf(this.samples[i], this.samples[j]);
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.kernel;
    }

    /** @return whether batch training may precompute the kernel matrix */
    public boolean isCacheKernel() {
        return this.cacheKernel;
    }

    /** @param z whether batch training may precompute the kernel matrix */
    public void setCacheKernel(boolean z) {
        this.cacheKernel = z;
    }
}
