package net.jkernelmachines.classifier;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.kernel.adaptative.ThreadedSumKernel;
import net.jkernelmachines.kernel.typed.DoubleLinear;
import net.jkernelmachines.type.TrainingSample;

/* loaded from: input_file:net/jkernelmachines/classifier/TSMKL.class */
/**
 * Multiple kernel learning classifier.
 *
 * <p>Non-negative kernel weights {@code beta} are learned with a Pegasos-style stochastic
 * subgradient method on "pair" samples. For every ordered pair (i, j) of training samples, the
 * feature vector is {@code [K_1(x_i, x_j), ..., K_m(x_i, x_j)]} and the label is
 * {@code y_i * y_j}. A {@link LaSVM} is then trained on the beta-weighted sum of the base
 * kernels whose weight is non-zero.
 *
 * @param <T> type of the input samples
 */
public class TSMKL<T> implements KernelSVM<T>, MKL<T> {
    /** Number of powers of ten tried for the initial step counter {@code t} (10^0 .. 10^8). */
    private static final int T_GRID_SIZE = 9;
    /** Number of passes over all pair samples once the initial {@code t} has been chosen. */
    private static final int FINAL_EPOCHS = 5;
    /** Scale below which the lazily-scaled weights are folded back into {@code beta}. */
    private static final double SCALE_EPSILON = 1.0E-9d;

    /** SVM trained on the weighted kernel sum; {@code null} until {@link #train} succeeds. */
    LaSVM<T> lasvm;
    /** Base kernels, in the order their weights appear in {@link #beta}. */
    List<Kernel<T>> kernels;
    /** Learned non-negative kernel weights, one per entry of {@link #kernels}. */
    double[] beta;
    /** Copy of the training set passed to {@link #train}. */
    List<TrainingSample<T>> tlist;
    /** Regularisation constant of the subgradient weight learner. */
    double lambda = 0.001d;
    /** Soft-margin constant forwarded to the final {@link LaSVM}. */
    double C = 10.0d;
    /** Step counter of the subgradient learner; the step size is {@code 1 / (lambda * t)}. */
    double t = 1.0d;
    private DoubleLinear linear = new DoubleLinear();

    /**
     * Learns the kernel weights and then trains the final SVM on {@code list}.
     *
     * <p>The pair samples are shuffled. The initial value of {@code t} is chosen by a grid search
     * that trains on one half and validates on the other. The weights are then relearned from
     * scratch over {@value #FINAL_EPOCHS} passes on all pairs. If no kernel has been added, only
     * the training set is stored and no SVM is trained.
     *
     * @param list training samples; copied, not retained
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        this.tlist = new ArrayList<>(list);
        if (this.kernels == null || this.kernels.isEmpty()) {
            // Nothing to weight: lasvm stays unset, so valueOf() etc. cannot be used afterwards.
            return;
        }
        List<TrainingSample<double[]>> pairs = buildPairSamples();
        Collections.shuffle(pairs);
        this.t = Math.pow(10.0d, selectInitialExponent(pairs));
        init();
        for (int epoch = 0; epoch < FINAL_EPOCHS; epoch++) {
            trainOnce(pairs);
        }
        learnSVM();
    }

    /**
     * Builds one sample per ordered pair (i, j) of {@link #tlist}. Each sample is the vector of
     * base-kernel values at (i, j), labelled with the product of the two samples' labels.
     */
    private List<TrainingSample<double[]>> buildPairSamples() {
        List<double[][]> matrices = new ArrayList<>(this.kernels.size());
        for (Kernel<T> kernel : this.kernels) {
            matrices.add(kernel.getKernelMatrix(this.tlist));
        }
        int n = this.tlist.size();
        List<TrainingSample<double[]>> pairs = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            TrainingSample<T> first = this.tlist.get(i);
            for (int j = 0; j < n; j++) {
                TrainingSample<T> second = this.tlist.get(j);
                double[] values = new double[this.kernels.size()];
                for (int k = 0; k < matrices.size(); k++) {
                    values[k] = matrices.get(k)[i][j];
                }
                pairs.add(new TrainingSample<>(values, first.label * second.label));
            }
        }
        return pairs;
    }

    /**
     * Grid-searches {@code t = 10^e} for {@code e} in {@code [0, T_GRID_SIZE)}.
     *
     * <p>For each candidate, the weights are trained from scratch on the first half of
     * {@code pairs} and the sign errors on the second half are counted.
     *
     * @return the exponent with the fewest validation errors (the smallest exponent on ties)
     */
    private int selectInitialExponent(List<TrainingSample<double[]>> pairs) {
        int half = pairs.size() / 2;
        List<TrainingSample<double[]>> trainPart = pairs.subList(0, half);
        List<TrainingSample<double[]>> validationPart = pairs.subList(half, pairs.size());
        int bestExponent = 0;
        int bestErrors = Integer.MAX_VALUE;
        for (int exponent = 0; exponent < T_GRID_SIZE; exponent++) {
            this.t = Math.pow(10.0d, exponent);
            init();
            trainOnce(trainPart);
            int errors = 0;
            for (TrainingSample<double[]> pair : validationPart) {
                if (pair.label * this.linear.valueOf(pair.sample, this.beta) < 0.0d) {
                    errors++;
                }
            }
            if (errors < bestErrors) {
                bestExponent = exponent;
                bestErrors = errors;
            }
        }
        return bestExponent;
    }

    /** Resets the kernel weights to all zeros. */
    private void init() {
        this.beta = new double[this.kernels.size()];
    }

    /**
     * Runs one pass of projected stochastic subgradient descent on the hinge loss over
     * {@code list}, updating {@link #beta} and advancing {@link #t}.
     *
     * <p>Weight decay is applied lazily. The effective weight vector is {@code scale * beta}.
     * It is folded back into {@code beta} whenever {@code scale} becomes tiny, and again at the
     * end of the pass.
     */
    private void trainOnce(List<TrainingSample<double[]>> list) {
        double scale = 1.0d;
        for (TrainingSample<double[]> pair : list) {
            double eta = 1.0d / (this.lambda * this.t);
            scale *= 1.0d - (eta * this.lambda);
            if (scale < SCALE_EPSILON) {
                // Also covers scale == 0 (t == 1), so the division below never sees 0.
                rescale(scale);
                scale = 1.0d;
            }
            double[] x = pair.sample;
            double y = pair.label;
            // The margin must be computed with the effective weights scale * beta.
            if (y * scale * this.linear.valueOf(x, this.beta) < 1.0d) {
                for (int k = 0; k < this.beta.length; k++) {
                    this.beta[k] += ((eta * x[k]) * y) / scale;
                    // Project onto the non-negative orthant (scale > 0, so this is exact).
                    if (this.beta[k] < 0.0d) {
                        this.beta[k] = 0.0d;
                    }
                }
            }
            this.t += 1.0d;
        }
        rescale(scale);
    }

    /** Multiplies every kernel weight by {@code factor}. */
    private void rescale(double factor) {
        for (int k = 0; k < this.beta.length; k++) {
            this.beta[k] *= factor;
        }
    }

    /**
     * Trains {@link #lasvm} with soft-margin constant {@link #C} on {@link #tlist}. The kernel is
     * the sum of the base kernels weighted by {@link #beta}; kernels with zero weight are skipped.
     */
    private void learnSVM() {
        ThreadedSumKernel<T> sumKernel = new ThreadedSumKernel<>();
        for (int k = 0; k < this.kernels.size(); k++) {
            if (this.beta[k] != 0.0d) {
                sumKernel.addKernel(this.kernels.get(k), this.beta[k]);
            }
        }
        this.lasvm = new LaSVM<>(sumKernel);
        this.lasvm.setC(this.C);
        this.lasvm.train(this.tlist);
    }

    /**
     * Returns the decision value of the trained SVM for {@code e}.
     *
     * <p>Throws {@link NullPointerException} if no SVM has been trained yet.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(T e) {
        return this.lasvm.valueOf(e);
    }

    /** Appends a base kernel; its weight is learned by the next call to {@link #train}. */
    @Override // net.jkernelmachines.classifier.MKL
    public void addKernel(Kernel<T> kernel) {
        if (this.kernels == null) {
            this.kernels = new ArrayList<>();
        }
        this.kernels.add(kernel);
    }

    /**
     * Returns a shallow {@code Object.clone()} copy that shares the kernel list, weights and SVM.
     *
     * <p>NOTE(review): this class does not itself implement {@code Cloneable}. This will throw
     * {@link CloneNotSupportedException} unless an implemented interface extends it — confirm.
     */
    @Override // net.jkernelmachines.classifier.Classifier
    @SuppressWarnings("unchecked")
    public TSMKL<T> copy() throws CloneNotSupportedException {
        return (TSMKL<T>) super.clone();
    }

    /**
     * Alias kept for callers of the decompiler-generated name.
     *
     * @deprecated use {@link #copy()}
     */
    @Deprecated
    public TSMKL<T> copy2() throws CloneNotSupportedException {
        return copy();
    }

    /** Ignored: base kernels are supplied through {@link #addKernel} instead. */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
    }

    /** Returns the dual coefficients of the trained SVM (NPE before training). */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        return this.lasvm.getAlphas();
    }

    /** Sets the soft-margin constant used by the next {@link #train}. */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    /** Returns the soft-margin constant. */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    /** Returns the internal weight array (not a copy); {@code null} before training. */
    @Override // net.jkernelmachines.classifier.MKL
    public double[] getKernelWeights() {
        return this.beta;
    }

    /** Returns the internal list of base kernels (not a copy); may be {@code null}. */
    @Override // net.jkernelmachines.classifier.MKL
    public List<Kernel<T>> getKernels() {
        return this.kernels;
    }

    /** Returns a new map from each base kernel to its learned weight. */
    @Override // net.jkernelmachines.classifier.MKL
    public Map<Kernel<T>, Double> getKernelWeightMap() {
        Map<Kernel<T>, Double> weights = new HashMap<>();
        for (int k = 0; k < this.kernels.size(); k++) {
            weights.put(this.kernels.get(k), this.beta[k]);
        }
        return weights;
    }

    /** Returns the weighted sum kernel used by the trained SVM (NPE before training). */
    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.lasvm.getKernel();
    }
}
