package net.jkernelmachines.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.kernel.SimpleCacheKernel;
import net.jkernelmachines.kernel.adaptative.ThreadedSumKernel;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.DebugPrinter;
import net.jkernelmachines.util.algebra.MatrixVectorOperations;
import net.jkernelmachines.util.algebra.VectorOperations;

/* loaded from: input_file:net/jkernelmachines/classifier/SequentialMKL.class */
/**
 * Multiple Kernel Learning SVM that builds its set of active kernels incrementally.
 *
 * <p>Training starts from {@link #initKernel} randomly chosen kernels. Each outer iteration (at
 * most {@link #T}) drops active kernels whose weight fell below {@link #numPrec}, scores every
 * inactive kernel by {@code (alpha.*y)' K (alpha.*y)} on the current support vectors, activates up
 * to {@link #kernelStep} kernels whose score exceeds the best active score by a relative margin of
 * {@link #kktPrec}, and re-optimizes the kernel weights. Kernel weights are kept non-negative and
 * normalized to sum to one. The inner SVM is a bias-free dual coordinate-descent solver.
 *
 * <p>{@link Cloneable} is implemented so that {@link #copy()} can return a shallow copy through
 * {@link Object#clone()}.
 */
public class SequentialMKL<T> implements KernelSVM<T>, MKL<T>, Serializable, Cloneable {
    private static final long serialVersionUID = 4200884871838129126L;

    /** Private copy of the training set given to {@link #train(List)}. */
    List<TrainingSample<T>> tlist;
    /** Candidate kernels registered through {@link #addKernel(Kernel)}. */
    List<Kernel<T>> kernels;
    /** Non-zero kernel weights selected by the last training, keyed by the underlying (uncached) kernels. */
    Map<Kernel<T>, Double> kernelMap;
    /** Weighted sum of the selected kernels, used by {@link #valueOf(Object)}. */
    Kernel<T> kernel;
    /** Dual coefficients of the last SVM solved by {@code trainSVM}. */
    double[] alpha;
    /** Training labels, aligned with {@link #tlist}. */
    double[] y;
    /** Not read or written by this class; kept as-is (presumably for serialized-form compatibility). */
    double[] beta;
    /** Decision values on the training set, maintained incrementally by {@code trainSVM}. */
    double[] yEst;
    DebugPrinter debug = new DebugPrinter();
    /** Upper bound of the box constraint on alpha (SVM regularization). */
    double C = 1.0d;
    /** Numerical tolerance: values below it are treated as zero. */
    double numPrec = 1.0E-10d;
    /** Relative tolerance of the kernel-activation and weight-optimality tests. */
    double kktPrec = 0.001d;
    /** Maximum number of passes of the SVM coordinate-descent solver. */
    long E = 250;
    /** Maximum number of outer active-set iterations. */
    long T = 10;
    /** Maximum number of kernels activated per outer iteration. */
    int kernelStep = 10;
    /** Requested size of the initial active set (clamped to [1, number of kernels]). */
    int initKernel = 1;

    /**
     * Trains the classifier on a copy of the given samples.
     *
     * @param list training samples; the list is copied, not retained
     * @throws IllegalStateException if no kernel was registered with {@link #addKernel(Kernel)}
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public void train(List<TrainingSample<T>> list) {
        if (kernels == null || kernels.isEmpty()) {
            throw new IllegalStateException("No kernel registered: call addKernel before train");
        }
        this.tlist = new ArrayList<TrainingSample<T>>(list);
        train();
    }

    /** Runs the active-set MKL optimization on {@link #tlist} and builds {@link #kernel}. */
    private void train() {
        y = new double[tlist.size()];
        for (int i = 0; i < y.length; i++) {
            y[i] = tlist.get(i).label;
        }

        // Inactive candidates in random order; scanned front to back each iteration.
        LinkedList<Kernel<T>> inactive = new LinkedList<Kernel<T>>(kernels);
        Collections.shuffle(inactive);

        // Active kernels (wrapped in a Gram-matrix cache over tlist) -> weight.
        Map<Kernel<T>, Double> active = new HashMap<Kernel<T>, Double>();
        // Clamp so that an out-of-range initKernel can neither exhaust the candidates nor be empty.
        int nbInit = Math.max(1, Math.min(initKernel, inactive.size()));
        for (int i = 0; i < nbInit; i++) {
            active.put(new SimpleCacheKernel<T>(inactive.removeFirst(), tlist), 1.0d / nbInit);
        }
        if (nbInit > 1) {
            trainWeights(active);
        }
        // Computed after trainWeights so the first improvement test uses the true baseline;
        // this also leaves alpha consistent with the active weights.
        double obj = computePrimalObj(active);
        debug.println(1, "Initial obj: " + obj);

        for (int iter = 0; iter < T; iter++) {
            // Invariant: alpha is the SVM solution for the current active weights.

            // Move kernels whose weight vanished back to the candidate list.
            int removed = 0;
            for (Iterator<Map.Entry<Kernel<T>, Double>> it = active.entrySet().iterator(); it.hasNext();) {
                Map.Entry<Kernel<T>, Double> entry = it.next();
                if (entry.getValue() < numPrec) {
                    inactive.add(entry.getKey());
                    it.remove();
                    removed++;
                }
            }

            long scanStart = System.currentTimeMillis();

            // Support vectors and their signed coefficients alpha_i * y_i.
            int nbSv = 0;
            for (double a : alpha) {
                if (a != 0.0d) {
                    nbSv++;
                }
            }
            double[] coef = new double[nbSv];
            List<TrainingSample<T>> sv = new ArrayList<TrainingSample<T>>(nbSv);
            for (int i = 0; i < alpha.length; i++) {
                if (alpha[i] != 0.0d) {
                    coef[sv.size()] = y[i] * alpha[i];
                    sv.add(tlist.get(i));
                }
            }
            debug.println(3, "nb non-zero alpha: " + nbSv);

            // Best and worst score among the active kernels.
            double gmax = 0.0d;
            double gmin = Double.POSITIVE_INFINITY;
            for (Kernel<T> k : active.keySet()) {
                double g = quadraticForm(unwrap(k), sv, coef);
                if (g > gmax) {
                    gmax = g;
                }
                if (g < gmin) {
                    gmin = g;
                }
            }
            debug.println(2, "gmin: " + gmin + "\tgmax: " + gmax + "\tnb: " + active.size());

            // Activate candidates that beat the best active kernel by a relative kktPrec margin.
            int nbInactive = inactive.size();
            int scanned = 0;
            int added = 0;
            for (Iterator<Kernel<T>> it = inactive.iterator(); it.hasNext() && added < kernelStep;) {
                Kernel<T> candidate = unwrap(it.next());
                double g = quadraticForm(candidate, sv, coef);
                scanned++;
                if ((g - gmax) / gmax > kktPrec) {
                    active.put(new SimpleCacheKernel<T>(candidate, tlist), 0.0d);
                    it.remove();
                    added++;
                }
            }
            long scanTime = System.currentTimeMillis() - scanStart;
            // Nothing is scanned once every kernel is active: avoid dividing by zero.
            long perKernel = scanned > 0 ? scanTime / scanned : 0L;
            debug.println(2, added + " added, " + scanned + " scan in " + scanTime + "ms (" + perKernel + "ms/k)");
            if (added == 0) {
                debug.println(1, "No new kernel added.");
                break;
            }

            long weightStart = System.currentTimeMillis();
            trainWeights(active);
            debug.println(2, "weights in " + (System.currentTimeMillis() - weightStart) + "ms");
            double newObj = computePrimalObj(active);
            double gain = (obj - newObj) / obj;
            debug.println(1, "iteration done with obj: " + newObj + " (" + gain + ")\tadded: " + added
                    + "\tremoved: " + removed + "\tsize: " + active.size());
            boolean scannedEnough = scanned >= (2 * nbInactive) / kernelStep;
            if ((scannedEnough || added - removed == 0) && gain < numPrec) {
                debug.println(1, "No improvement in obj, nor expanding kernel set, stoping.");
                break;
            }
            obj = newObj;
        }

        debug.println(1, "finished!");
        kernelMap = new HashMap<Kernel<T>, Double>(active.size());
        for (Map.Entry<Kernel<T>, Double> entry : active.entrySet()) {
            double w = entry.getValue();
            if (w != 0.0d) {
                Kernel<T> k = entry.getKey();
                // Strip the training-set cache so the map is keyed by the underlying kernel.
                if (k instanceof SimpleCacheKernel) {
                    k = ((SimpleCacheKernel<T>) k).getKernel();
                }
                kernelMap.put(k, w);
            }
        }
        debug.println(2, "kmap: " + kernelMap);
        // Defensive re-solve so that alpha is guaranteed to match the final weights.
        computePrimalObj(active);
        kernel = new ThreadedSumKernel<T>(kernelMap);
    }

    /** Strips every SimpleCacheKernel wrapper and returns the innermost kernel. */
    private Kernel<T> unwrap(Kernel<T> k) {
        while (k instanceof SimpleCacheKernel) {
            k = ((SimpleCacheKernel<T>) k).getKernel();
        }
        return k;
    }

    /** Returns {@code coef' K coef}, where K is the Gram matrix of {@code k} over {@code samples}. */
    private double quadraticForm(Kernel<T> k, List<TrainingSample<T>> samples, double[] coef) {
        return VectorOperations.dot(coef, MatrixVectorOperations.rMul(k.getKernelMatrix(samples), coef));
    }

    /**
     * Optimizes in place the weights of the kernels in {@code map}, keeping them non-negative and
     * summing to one. Each step follows a reduced-gradient direction in which the kernel with the
     * largest weight compensates all the others (this resembles the SimpleMKL update), with the
     * step length chosen by {@link #lineSearch}.
     *
     * <p>On the early-exit paths {@link #alpha} may come from a line-search trial rather than from
     * {@code map}; callers re-solve the SVM afterwards.
     *
     * @param map kernel -> weight; modified in place. Maps with fewer than two kernels are left unchanged.
     */
    private void trainWeights(Map<Kernel<T>, Double> map) {
        if (map == null || map.size() < 2) {
            return;
        }
        // Sized from tlist: alpha may not have been computed yet when this is first called.
        double[] ay = new double[tlist.size()];
        double obj = computePrimalObj(map);
        while (true) {
            // alpha is consistent with map here (solved just above or at the end of the last step).
            VectorOperations.prodi(ay, alpha, y);

            // Reference kernel: the one with the largest weight.
            Kernel<T> ref = null;
            double refWeight = 0.0d;
            for (Map.Entry<Kernel<T>, Double> entry : map.entrySet()) {
                if (entry.getValue() > refWeight) {
                    ref = entry.getKey();
                    refWeight = entry.getValue();
                }
            }
            if (ref == null) {
                debug.println(2, "no kernel with positive weight, exiting");
                return;
            }
            double gRef = 0.5d * quadraticForm(ref, tlist, ay);
            debug.println(3, "gmax: " + gRef);

            Map<Kernel<T>, Double> direction = new HashMap<Kernel<T>, Double>(map.size());
            double kkt = 0.0d;
            double sum = 0.0d;
            for (Map.Entry<Kernel<T>, Double> entry : map.entrySet()) {
                Kernel<T> k = entry.getKey();
                double g = 0.5d * quadraticForm(k, tlist, ay) - gRef;
                // A zero-weight kernel may only move if that increases its weight.
                if (g > 0.0d || entry.getValue() > 0.0d) {
                    direction.put(k, g);
                    sum += g;
                    kkt += Math.abs(g);
                } else {
                    direction.put(k, 0.0d);
                }
            }
            // The reference kernel balances the others so the weights keep summing to one.
            direction.put(ref, -sum);
            debug.println(3, "new Gradient: " + direction);
            debug.println(3, "KKT: " + kkt + "\tksum: " + (-sum) + "\t rKKT: " + ((kkt / gRef) / direction.size())
                    + "\tgmax: " + gRef);
            if (kkt / gRef < kktPrec) {
                debug.println(2, "KKT satisfied (" + (kkt / gRef) + "), exiting");
                return;
            }

            double step = lineSearch(map, direction, obj);
            if (step < numPrec) {
                debug.println(3, "no improvement: " + step);
                return;
            }
            Map<Kernel<T>, Double> next = stepWeights(map, direction, step);
            map.clear();
            map.putAll(next);
            double newObj = computePrimalObj(map);
            debug.println(3, "step: " + step + "\t oldObj: " + obj + "\t newObj: " + newObj);
            if (newObj > obj) {
                debug.println(2, "!!! rising objective !!!\t oldObj: " + obj + "\t newObj: " + newObj);
            } else if (obj - newObj < numPrec) {
                debug.println(2, "Line search gave no obj improvement, stoping");
                return;
            }
            obj = newObj;
        }
    }

    /**
     * Returns the weights {@code map + gamma * direction}. Values below {@link #numPrec} are set
     * to zero, the rest are renormalized to sum to one, and renormalized values not above
     * {@code numPrec} are zeroed again.
     */
    private Map<Kernel<T>, Double> stepWeights(Map<Kernel<T>, Double> map, Map<Kernel<T>, Double> direction,
            double gamma) {
        Map<Kernel<T>, Double> next = new HashMap<Kernel<T>, Double>(direction.size());
        double total = 0.0d;
        for (Map.Entry<Kernel<T>, Double> entry : direction.entrySet()) {
            Double current = map.get(entry.getKey());
            double w = (current == null ? 0.0d : current) + gamma * entry.getValue();
            if (w < numPrec) {
                w = 0.0d;
            }
            total += w;
            next.put(entry.getKey(), w);
        }
        for (Map.Entry<Kernel<T>, Double> entry : next.entrySet()) {
            double w = entry.getValue() / total;
            entry.setValue(w > numPrec ? w : 0.0d);
        }
        return next;
    }

    /**
     * Chooses a step along {@code direction} that lowers the primal objective.
     *
     * <p>NOTE(review): the original body could not be decompiled (it only threw
     * UnsupportedOperationException). This is a reconstruction: a golden-section search on
     * [0, gammaMax], where gammaMax is the largest step keeping every weight non-negative. The
     * endpoint gammaMax is always evaluated so that a weight can be driven exactly to zero.
     * Every evaluation re-solves the SVM and overwrites {@link #alpha}.
     *
     * @param map current kernel weights (not modified)
     * @param direction descent direction, keyed like {@code map}
     * @param objAtZero primal objective at the current weights
     * @return the best evaluated step, or 0 when no evaluated step improves on {@code objAtZero}
     */
    private double lineSearch(Map<Kernel<T>, Double> map, Map<Kernel<T>, Double> direction, double objAtZero) {
        double gammaMax = Double.POSITIVE_INFINITY;
        for (Map.Entry<Kernel<T>, Double> entry : direction.entrySet()) {
            double d = entry.getValue();
            if (d < 0.0d) {
                Double w = map.get(entry.getKey());
                gammaMax = Math.min(gammaMax, (w == null ? 0.0d : w) / -d);
            }
        }
        if (Double.isInfinite(gammaMax) || !(gammaMax > numPrec)) {
            return 0.0d;
        }

        double bestGamma = 0.0d;
        double bestObj = objAtZero;
        double objAtMax = computePrimalObj(stepWeights(map, direction, gammaMax));
        if (objAtMax < bestObj) {
            bestGamma = gammaMax;
            bestObj = objAtMax;
        }

        final double invPhi = 0.5d * (Math.sqrt(5.0d) - 1.0d);
        double lo = 0.0d;
        double hi = gammaMax;
        double x1 = hi - invPhi * (hi - lo);
        double x2 = lo + invPhi * (hi - lo);
        double o1 = computePrimalObj(stepWeights(map, direction, x1));
        double o2 = computePrimalObj(stepWeights(map, direction, x2));
        if (Math.max(o1, o2) > Math.max(objAtZero, objAtMax)) {
            debug.println(3, "!!!! non cvx !!! o0: " + objAtZero + "\to1: " + o1 + "\to2: " + o2 + "\tomax: " + objAtMax);
        }
        if (o1 < bestObj) {
            bestGamma = x1;
            bestObj = o1;
        }
        if (o2 < bestObj) {
            bestGamma = x2;
            bestObj = o2;
        }

        double tolerance = kktPrec * gammaMax;
        for (int it = 0; it < 100 && hi - lo > tolerance; it++) {
            if (o1 <= o2) {
                // Minimum lies in [lo, x2].
                hi = x2;
                x2 = x1;
                o2 = o1;
                x1 = hi - invPhi * (hi - lo);
                o1 = computePrimalObj(stepWeights(map, direction, x1));
                if (o1 < bestObj) {
                    bestGamma = x1;
                    bestObj = o1;
                }
            } else {
                // Minimum lies in [x1, hi].
                lo = x1;
                x1 = x2;
                o1 = o2;
                x2 = lo + invPhi * (hi - lo);
                o2 = computePrimalObj(stepWeights(map, direction, x2));
                if (o2 < bestObj) {
                    bestGamma = x2;
                    bestObj = o2;
                }
            }
        }
        debug.println(3, "line search: gammaMax: " + gammaMax + "\tstep: " + bestGamma + "\tobj: " + bestObj);
        return bestGamma;
    }

    /**
     * Solves the bias-free SVM dual by randomized coordinate descent with box constraint [0, C].
     * The visiting order is shuffled with a fixed seed, so results are deterministic.
     * Updates {@link #alpha} and {@link #yEst}; stops when a full pass changes nothing or after
     * more than {@link #E} passes.
     *
     * @param gram Gram matrix of the combined kernel over {@link #tlist}
     */
    private void trainSVM(double[][] gram) {
        int n = tlist.size();
        alpha = new double[n];
        yEst = new double[n];
        List<Integer> order = new ArrayList<Integer>(n);
        for (int i = 0; i < n; i++) {
            order.add(i);
        }
        Random rnd = new Random(42L);
        int pass = 0;
        boolean changed = true;
        while (changed) {
            changed = false;
            Collections.shuffle(order, rnd);
            for (int i : order) {
                double grad = 1.0d - y[i] * yEst[i];
                // Skip when the gradient is negligible or the box constraint blocks the move.
                boolean movable = Math.abs(grad) >= numPrec
                        && (grad >= -numPrec || alpha[i] > numPrec)
                        && (grad <= numPrec || alpha[i] < C - numPrec);
                if (!movable) {
                    continue;
                }
                changed = true;
                double a = alpha[i] + grad / gram[i][i];
                if (a > C - numPrec) {
                    a = C;
                }
                if (a < numPrec) {
                    a = 0.0d;
                }
                double delta = a - alpha[i];
                for (int j = 0; j < n; j++) {
                    yEst[j] += delta * y[i] * gram[i][j];
                }
                alpha[i] = a;
            }
            pass++;
            if (pass > E) {
                debug.println(4, "Too many iterations in optimizing alpha");
                return;
            }
        }
    }

    /**
     * Solves the SVM for the combined kernel {@code sum_k w_k K_k} and returns its primal
     * objective {@code 0.5 * (alpha.*y)' K (alpha.*y) + C * sum_i max(0, 1 - y_i f(x_i))}.
     * Hinge terms not above {@link #numPrec} are ignored. Side effect: overwrites
     * {@link #alpha} and {@link #yEst}.
     */
    private double computePrimalObj(Map<Kernel<T>, Double> map) {
        double[][] gram = new ThreadedSumKernel<T>(map).getKernelMatrix(tlist);
        trainSVM(gram);
        double[] ay = VectorOperations.prod(alpha, y);
        double[] f = MatrixVectorOperations.rMul(gram, ay);
        double norm = VectorOperations.dot(ay, f);
        double hinge = 0.0d;
        for (int i = 0; i < f.length; i++) {
            double loss = 1.0d - y[i] * f[i];
            if (loss > numPrec) {
                hinge += loss;
            }
        }
        return 0.5d * norm + C * hinge;
    }

    /**
     * Returns the decision value {@code sum_i alpha_i * y_i * K(x_i, t)} (no bias term).
     *
     * @throws IllegalStateException if the classifier has not been trained
     */
    @Override // net.jkernelmachines.classifier.Classifier
    public double valueOf(T t) {
        if (kernel == null || alpha == null) {
            throw new IllegalStateException("SequentialMKL must be trained before valueOf");
        }
        double sum = 0.0d;
        for (int i = 0; i < tlist.size(); i++) {
            if (alpha[i] == 0.0d) {
                continue; // not a support vector: skip the kernel evaluation
            }
            TrainingSample<T> s = tlist.get(i);
            sum += alpha[i] * s.label * kernel.valueOf(s.sample, t);
        }
        return sum;
    }

    /** Returns a shallow copy of this classifier (fields are shared with the original). */
    @Override // net.jkernelmachines.classifier.Classifier
    @SuppressWarnings("unchecked") // clone() of this class always returns a SequentialMKL
    public Classifier<T> copy() throws CloneNotSupportedException {
        return (SequentialMKL<T>) clone();
    }

    /**
     * Decompiler-renamed alias of {@link #copy()}.
     *
     * @deprecated use {@link #copy()} instead
     */
    @Deprecated
    public Classifier<T> copy2() throws CloneNotSupportedException {
        return copy();
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public Kernel<T> getKernel() {
        return this.kernel;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setKernel(Kernel<T> kernel) {
        this.kernel = kernel;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public double[] getAlphas() {
        return this.alpha;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public void setC(double d) {
        this.C = d;
    }

    @Override // net.jkernelmachines.classifier.KernelSVM
    public double getC() {
        return this.C;
    }

    /**
     * Returns the learned weight of each registered kernel, in {@link #getKernels()} order.
     * Returns 0 for unselected kernels and all zeros before training. Returns an empty array when
     * no kernel was registered.
     */
    @Override // net.jkernelmachines.classifier.MKL
    public double[] getKernelWeights() {
        if (kernels == null) {
            return new double[0];
        }
        double[] weights = new double[kernels.size()];
        if (kernelMap == null) {
            return weights;
        }
        for (int i = 0; i < weights.length; i++) {
            Double w = kernelMap.get(kernels.get(i));
            if (w != null) {
                weights[i] = w;
            }
        }
        return weights;
    }

    @Override // net.jkernelmachines.classifier.MKL
    public List<Kernel<T>> getKernels() {
        return this.kernels;
    }

    @Override // net.jkernelmachines.classifier.MKL
    public Map<Kernel<T>, Double> getKernelWeightMap() {
        return this.kernelMap;
    }

    /** Registers a candidate kernel for the next call to {@link #train(List)}. */
    @Override // net.jkernelmachines.classifier.MKL
    public void addKernel(Kernel<T> kernel) {
        if (this.kernels == null) {
            this.kernels = new ArrayList<Kernel<T>>();
        }
        this.kernels.add(kernel);
    }
}
