package net.jkernelmachines.classifier.multiclass;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.jkernelmachines.classifier.Classifier;
import net.jkernelmachines.classifier.KernelSVM;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.algebra.MatrixOperations;
import net.jkernelmachines.util.algebra.MatrixVectorOperations;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/classifier/multiclass/MulticlassSDCA.class */

/**
 * Multiclass kernel SVM whose dual coefficients are optimized by a stochastic,
 * per-sample coordinate procedure (SDCA, per the class name).
 *
 * <p>The model keeps one row of dual coefficients per training sample and one
 * column per class, {@code alpha[i][c]}. The score of class {@code c} for an
 * input {@code x} is {@code sum_i alpha[i][c] * K(x_i, x)}, and the predicted
 * label is the class with the highest score.
 *
 * <p>Training makes {@code E * nb_classes} passes over the training set, each
 * in a freshly shuffled order. A sample whose own class already scores strictly
 * above every other class is left unchanged; otherwise its coefficient row is
 * updated so that its non-label entries stay non-positive, its label entry is
 * capped at {@code C}, and the row is checked to sum to zero (within 1e-10).
 *
 * <p>No synchronization is performed: {@link #train(List)} replaces the model
 * fields in place.
 */
public class MulticlassSDCA<T> implements MulticlassClassifier<T>, KernelSVM<T>, Cloneable {
    /** Kernel used both for the training Gram matrix and for scoring new inputs. */
    Kernel<T> kernel;
    /** Private copy of the samples passed to the last call to {@link #train(List)}. */
    List<TrainingSample<T>> tlist;
    /** Dual coefficients, indexed {@code [sample][class column]}. */
    double[][] alpha;
    /** Distinct training labels in order of first appearance; list index = class column. */
    List<Integer> classes;
    /** Upper bound on each sample's coefficient in its own label column. */
    double C = 1.0d;
    /** Training runs {@code E * nb_classes} shuffled passes over the data. */
    double E = 25.0d;
    /** Number of distinct labels seen by the last call to {@link #train(List)}. */
    int nb_classes = 0;

    /**
     * Creates an untrained classifier.
     *
     * @param kernel kernel used for training and prediction
     */
    public MulticlassSDCA(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /**
     * Trains the model on the given samples, discarding any previous model.
     *
     * @param list training samples; the list is copied, so later changes to it
     *             do not affect the model
     * @throws IllegalStateException if an update leaves a coefficient row whose
     *         sum deviates from zero by more than 1e-10
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        this.tlist = new ArrayList<TrainingSample<T>>(list);
        final int n = this.tlist.size();
        double[][] kernelMatrix = this.kernel.getKernelMatrix(this.tlist);

        // Give each distinct label a class column (in order of first appearance) and
        // remember every sample's column so the update loop needs no list lookups.
        this.nb_classes = 0;
        this.classes = new ArrayList<Integer>();
        int[] labelColumn = new int[n];
        List<Integer> order = new ArrayList<Integer>(n);
        for (int i = 0; i < n; i++) {
            Integer label = Integer.valueOf(this.tlist.get(i).label);
            int column = this.classes.indexOf(label);
            if (column < 0) {
                column = this.classes.size();
                this.classes.add(label);
            }
            labelColumn[i] = column;
            order.add(Integer.valueOf(i));
        }
        this.nb_classes = this.classes.size();

        this.alpha = new double[n][this.nb_classes];
        // Scratch buffers reused by every update to avoid per-sample allocation.
        double[][] alphaT = new double[this.nb_classes][n];
        double[] scores = new double[this.nb_classes];
        for (int pass = 0; pass < this.E * this.nb_classes; pass++) {
            Collections.shuffle(order);
            for (Integer sample : order) {
                int i = sample.intValue();
                updateSample(i, labelColumn[i], kernelMatrix, alphaT, scores);
            }
        }
    }

    /**
     * Performs one coordinate update on the coefficient row of training sample {@code i}.
     *
     * @param i            index of the training sample
     * @param y            class column of sample {@code i}'s label
     * @param kernelMatrix training Gram matrix
     * @param alphaT       scratch buffer that receives the transpose of {@link #alpha}
     * @param scores       scratch buffer that receives sample {@code i}'s current class scores
     * @throws IllegalStateException if the updated row does not sum to zero (within 1e-10)
     */
    private void updateSample(int i, int y, double[][] kernelMatrix, double[][] alphaT, double[] scores) {
        final int k = this.nb_classes;
        MatrixVectorOperations.rMuli(scores, MatrixOperations.transi(alphaT, this.alpha), kernelMatrix[i]);
        if (isStrictlyBest(scores, y)) {
            return; // the sample's own class already wins: nothing to update
        }
        double kii = kernelMatrix[i][i];
        if (kii == 0.0d) {
            // The step below divides by K(x_i, x_i). With a zero self-similarity the step
            // is infinite and the rescaling produces NaN coefficients that slip past the
            // sum check. For a PSD kernel such a sample has K(x, x_i) = 0 for every x, so
            // skipping it cannot change any score.
            return;
        }
        double[] row = this.alpha[i];

        // Unclipped step: 1 + (kernel-weighted sum of all coefficient rows), scaled by
        // -1 / K(x_i, x_i). NOTE(review): assumes VectorOperations.muli(out, v, s)
        // computes v * s and addi(out, a, s, b) computes a + s * b, as their use
        // in the scoring helper of this class suggests.
        double[] step = new double[k];
        Arrays.fill(step, 1.0d);
        double[] term = new double[k];
        for (int j = 0; j < this.tlist.size(); j++) {
            VectorOperations.muli(term, this.alpha[j], kernelMatrix[i][j]);
            VectorOperations.addi(step, step, 1.0d, term);
        }
        VectorOperations.muli(step, step, -1.0d / kii);
        step[y] = 0.0d;

        // Clip the non-label steps so those coefficients cannot become positive,
        // and accumulate their total.
        double total = 0.0d;
        for (int c = 0; c < k; c++) {
            if (c == y) {
                continue;
            }
            if (row[c] >= 0.0d && step[c] > 0.0d) {
                step[c] = 0.0d;
            }
            if (row[c] + step[c] > 0.0d) {
                step[c] = -row[c];
            }
            total += step[c];
        }

        // The label coefficient moves by -total; shrink the non-label steps
        // proportionally if that would push it above C.
        if (row[y] - total > this.C) {
            double scale = (row[y] - this.C) / total;
            for (int c = 0; c < k; c++) {
                if (c != y) {
                    step[c] *= scale;
                }
            }
            total = row[y] - this.C;
        }
        step[y] = -total;

        double[] updated = VectorOperations.add(row, 1.0d, step);
        for (int c = 0; c < k; c++) {
            if (c != y && updated[c] > 0.0d) {
                updated[c] = 0.0d;
            }
        }
        double sum = 0.0d;
        for (int c = 0; c < k; c++) {
            sum += updated[c];
        }
        if (Math.abs(sum) > 1.0E-10d) {
            // Fail the training call instead of terminating the whole JVM.
            throw new IllegalStateException("error with " + i + " sum(a)= " + sum + " a: " + Arrays.toString(updated));
        }
        this.alpha[i] = updated;
    }

    /**
     * Tells whether class column {@code y} scores strictly higher than every other column.
     *
     * @param scores per-class scores
     * @param y      column to test
     * @return {@code true} if no other column reaches {@code scores[y]}
     */
    private static boolean isStrictlyBest(double[] scores, int y) {
        for (int c = 0; c < scores.length; c++) {
            if (c != y && scores[c] >= scores[y]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the per-class scores {@code sum_i alpha[i][c] * K(t, x_i)} for an input.
     * Must only be called on a trained model.
     *
     * @param t input to score
     * @return array of length {@link #nb_classes}, indexed by class column
     */
    private double[] classScores(T t) {
        double[] scores = new double[this.nb_classes];
        for (int i = 0; i < this.tlist.size(); i++) {
            VectorOperations.addi(scores, scores, this.kernel.valueOf(t, this.tlist.get(i).sample), this.alpha[i]);
        }
        return scores;
    }

    /**
     * Predicts the label of an input.
     *
     * @param t input to classify
     * @return the label whose class score is highest (the first one on ties),
     *         or {@code 0.0} if the model has no classes
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(T t) {
        if (this.nb_classes <= 0) {
            return 0.0d;
        }
        double[] scores = classScores(t);
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.nb_classes; c++) {
            if (scores[c] > bestScore) {
                best = c;
                bestScore = scores[c];
            }
        }
        return this.classes.get(best).intValue();
    }

    /**
     * Returns a shallow copy of this classifier made with {@link Object#clone()}.
     *
     * <p>The copy shares the kernel, training list and coefficient arrays with this
     * instance; retraining either instance assigns fresh fields rather than mutating
     * the shared ones.
     *
     * @return the copy
     * @throws CloneNotSupportedException declared by {@link Object#clone()}; the class
     *         implements {@link Cloneable}, so it is not expected here
     */
    @Override // net.jkernelmachines.classifier.Classifier
    @SuppressWarnings("unchecked") // clone() returns Object; the runtime class is this class
    // NOTE(review): the decompiler renamed this method from "copy"; the name is kept so
    // it keeps matching the rest of the decompiled sources.
    public Classifier<T> copy2() throws CloneNotSupportedException {
        return (MulticlassSDCA<T>) clone();
    }

    /**
     * Returns the highest class score for an input.
     *
     * @param t input to score
     * @return the maximum class score, or {@code 0.0} if the model has no classes
     */
    @Override // net.jkernelmachines.classifier.multiclass.MulticlassClassifier
    public double getConfidence(T t) {
        if (this.nb_classes <= 0) {
            return 0.0d;
        }
        double[] scores = classScores(t);
        double best = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.nb_classes; c++) {
            if (scores[c] > best) {
                best = scores[c];
            }
        }
        return best;
    }

    /**
     * Returns the score of every class for an input.
     *
     * @param t input to score
     * @return map from label to class score, or {@code null} if the model has no classes
     */
    @Override // net.jkernelmachines.classifier.multiclass.MulticlassClassifier
    public Map<Integer, Double> getConfidences(T t) {
        if (this.nb_classes <= 0) {
            return null;
        }
        double[] scores = classScores(t);
        Map<Integer, Double> confidences = new HashMap<Integer, Double>();
        for (int c = 0; c < scores.length; c++) {
            confidences.put(this.classes.get(c), Double.valueOf(scores[c]));
        }
        return confidences;
    }

    /** @return the kernel used for training and prediction */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.kernel;
    }

    /**
     * Replaces the kernel. An existing model is not retrained.
     *
     * @param kernel new kernel
     */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    /** @return the upper bound {@code C} on each sample's label-column coefficient */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    /** @param d new upper bound {@code C}, used by the next call to {@link #train(List)} */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    /** @return the pass multiplier {@code E}; training runs {@code E * nb_classes} passes */
    public double getE() {
        return this.E;
    }

    /** @param d new pass multiplier {@code E}, used by the next call to {@link #train(List)} */
    public void setE(double d) {
        this.E = d;
    }

    /**
     * Not supported: this model has one coefficient per sample and class, not a single
     * vector. Use {@link #getMulticlassAlphas()} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        throw new UnsupportedOperationException("operation not possible");
    }

    /**
     * Returns the internal coefficient matrix (not a copy).
     *
     * @return coefficients indexed {@code [sample][class column]}, or {@code null} before training
     */
    public double[][] getMulticlassAlphas() {
        return this.alpha;
    }
}
