package net.jkernelmachines.classifier;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;

import net.jkernelmachines.kernel.Kernel;
import net.jkernelmachines.kernel.typed.DoubleLinear;
import net.jkernelmachines.kernel.typed.GeneralizedDoubleGaussL2;
import net.jkernelmachines.threading.ThreadPoolServer;
import net.jkernelmachines.threading.ThreadedMatrixOperator;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.DebugPrinter;

/* loaded from: input_file:net/jkernelmachines/classifier/DoubleQNPKL.class */
/**
 * SVM on {@code double[]} samples whose kernel is a {@link GeneralizedDoubleGaussL2} with one
 * learned, non-negative weight per input dimension.
 *
 * <p>Training alternates two phases:
 * <ul>
 *   <li>fitting a {@link LaSVM} for fixed weights;</li>
 *   <li>moving the weights along {@code -B * grad}.</li>
 * </ul>
 * Here {@code grad} comes from {@link #computeGrad} (presumably the gradient of the objective with
 * respect to the weights — confirm against {@link GeneralizedDoubleGaussL2}). {@code B} is a
 * per-coordinate secant estimate: the weight change divided by the gradient change from the
 * previous iteration. A trial step is accepted only when it lowers
 * {@code sum(alpha) - 0.5 * sum_ij alpha_i y_i alpha_j y_j K_ij}; otherwise the step size is
 * divided by ten. The outer loop stops once {@code 1 - newObjective / oldObjective} drops below
 * {@link #getStopGap()}.
 */
public class DoubleQNPKL implements KernelSVM<double[]>, Serializable {
    private static final long serialVersionUID = -5475712590325368437L;

    /** SVM trained with the current kernel weights; {@code null} until {@link #train} has run. */
    LaSVM<double[]> svm;
    /** Objective value of the most recently accepted weights. */
    double oldObjective;
    /** Weight change produced by the previous outer step; input to {@link #computeB}. */
    double[] diffWeights;
    /** Gradient used by the previous outer step; input to {@link #computeB}. */
    double[] g;
    /** Current per-dimension kernel weights. */
    double[] weights;
    /** Per-coordinate step scaling produced by {@link #computeB}. */
    double[] B;
    /** Number of input dimensions, i.e. number of kernel weights. */
    int dim = 0;
    /** Debug output; transient, so it is re-created in {@link #readObject}. */
    transient DebugPrinter debug = new DebugPrinter();
    // NOTE(review): not read anywhere in this class; kept so the serialized form is unchanged.
    DoubleLinear linear = new DoubleLinear();
    /** Outer loop stops when the relative objective decrease falls below this value. */
    double stopGap = 1.0E-7d;
    /**
     * Numerical threshold. Weights below it are zeroed, gradient entries with smaller magnitude
     * are zeroed, it is the minimum of {@code B}, and the line search stops once the step
     * factor reaches it.
     */
    double num_cleaning = 1.0E-7d;
    /** Stored by {@link #setPNorm}; not read by the training code in this class. */
    double p_norm = 1.0d;
    /** When {@code true}, each trial weight vector is rescaled to sum to one. */
    boolean hasNorm = false;
    /** SVM regularization constant passed to every {@link LaSVM}. */
    double C = 1.0d;
    /** Initial value of every entry of {@code B} on the first step of a training run. */
    double d_lambda = 0.1d;
    /**
     * {@code lambda_matrix[i][j] = a_i * y_i * a_j * y_j * K(x_i, x_j)} for the last SVM passed to
     * {@link #updateLambdaMatrix}. Row {@code i} is {@code null} when {@code a_i == 0}.
     */
    double[][] lambda_matrix;
    /** Learned kernel weights, filled at the end of {@link #train}. */
    List<Double> listOfKernelWeights = new ArrayList<>();
    /** Training set of the last call to {@link #train}. */
    List<TrainingSample<double[]>> listOfExamples = new ArrayList<>();
    /** Alphas of the final SVM, one per entry of {@link #listOfExamples}. */
    List<Double> listOfExampleWeights = new ArrayList<>();

    /**
     * Restores the transient debug printer after deserialization.
     *
     * <p>Without this, {@link #train} would fail with a {@code NullPointerException} on a
     * deserialized instance.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.debug = new DebugPrinter();
    }

    /**
     * Matrix operator that accumulates, for each registered matrix {@code M_k}, the value
     * {@code 0.5 * sum_ij M_k[i][j] * lambda_matrix[i][j]} into {@code grad[k]}.
     *
     * <p>{@link ThreadedMatrixOperator} hands out row ranges to {@link #doLines}. Each range's
     * partial sums are merged into {@code grad} while holding the lock of the {@code grad} array.
     * Not referenced by the training code in this class.
     */
    class GradMAtrixOperator extends ThreadedMatrixOperator {
        double[] grad;
        Map<Integer, double[][]> matrices = new HashMap<>();

        GradMAtrixOperator() {
        }

        /** Registers matrix {@code dArr} as the contribution to gradient entry {@code i}. */
        public void addMatrix(int i, double[][] dArr) {
            this.matrices.put(Integer.valueOf(i), dArr);
        }

        /** Sets the array that {@link #doLines} accumulates into. */
        public void setGrad(double[] dArr) {
            this.grad = dArr;
        }

        /** Removes all registered matrices. */
        public void clearMatrices() {
            this.matrices.clear();
        }

        /**
         * Adds, for rows {@code from} (inclusive) to {@code to} (exclusive), each registered
         * matrix's weighted sum against {@code lambda_matrix} into {@code grad}.
         */
        @Override
        public void doLines(double[][] matrix, int from, int to) {
            Map<Integer, Double> partialSums = new HashMap<>();
            for (Map.Entry<Integer, double[][]> entry : this.matrices.entrySet()) {
                double[][] contribution = entry.getValue();
                double sum = 0.0d;
                for (int row = from; row < to; row++) {
                    double[] lambdaRow = DoubleQNPKL.this.lambda_matrix[row];
                    if (lambdaRow == null) {
                        continue; // zero alpha: the whole row contributes nothing
                    }
                    for (int col = 0; col < matrix.length; col++) {
                        sum += contribution[row][col] * lambdaRow[col];
                    }
                }
                partialSums.put(entry.getKey(), Double.valueOf(sum));
            }
            // Several row ranges may run concurrently and write to the same gradient entries.
            synchronized (this.grad) {
                for (Map.Entry<Integer, Double> partial : partialSums.entrySet()) {
                    int index = partial.getKey().intValue();
                    this.grad[index] = this.grad[index] + (0.5d * partial.getValue().doubleValue());
                }
            }
        }
    }

    /**
     * Learns the per-dimension kernel weights and the final SVM from {@code list}.
     *
     * <p>The number of weights is taken from the length of the first sample.
     *
     * @param list training samples; must not be {@code null} or empty
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     * @throws IllegalStateException if an optimisation step reports an invalid objective ratio or
     *     the gradient computation fails
     */
    @Override
    public void train(List<TrainingSample<double[]>> list) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("cannot train on a null or empty list of examples");
        }
        long startMillis = System.currentTimeMillis();
        this.dim = list.get(0).sample.length;
        this.debug.println(2, "training on " + this.dim + " kernels and " + list.size() + " examples");
        this.listOfExamples = new ArrayList<>(list);

        // Forget optimiser state from any previous training run. The secant estimate in
        // computeB and the lambda matrix depend on the previous data set and dimension.
        this.g = null;
        this.diffWeights = null;
        this.B = null;
        this.lambda_matrix = null;

        // Start from uniform weights.
        this.weights = new double[this.dim];
        Arrays.fill(this.weights, 1.0d / this.dim);

        GeneralizedDoubleGaussL2 initialKernel = new GeneralizedDoubleGaussL2(this.weights);
        this.svm = trainSVM(initialKernel);
        double[] initialAlphas = this.svm.getAlphas();
        updateLambdaMatrix(initialAlphas, initialKernel);
        this.oldObjective = computeObj(initialAlphas);
        this.debug.println(2, "+ initial objective : " + this.oldObjective);
        this.debug.println(3, "+ initial weights : " + Arrays.toString(this.weights));

        double gap;
        do {
            double ratio = performPKLStep();
            if (ratio < 0.0d) {
                this.debug.println(1, "Error, performPKLStep return wrong value");
                throw new IllegalStateException(
                        "performPKLStep returned a negative objective ratio: " + ratio);
            }
            gap = 1.0d - ratio;
            this.debug.println(1, "+ objective_gap : " + ((float) gap));
            this.debug.println(1, "+");
        } while (gap >= this.stopGap);

        this.listOfKernelWeights = new ArrayList<>(this.weights.length);
        for (double weight : this.weights) {
            this.listOfKernelWeights.add(Double.valueOf(weight));
        }

        // Retrain with the final weights. The lambda matrix may still describe a rejected
        // trial step.
        GeneralizedDoubleGaussL2 finalKernel = new GeneralizedDoubleGaussL2(this.weights);
        this.svm = trainSVM(finalKernel);
        double[] finalAlphas = this.svm.getAlphas();
        updateLambdaMatrix(finalAlphas, finalKernel);
        this.debug.println(2, "+ final objective : " + computeObj(finalAlphas));

        // listOfExamples already holds exactly the training set, one entry per alpha.
        this.listOfExampleWeights.clear();
        for (double alpha : finalAlphas) {
            this.listOfExampleWeights.add(Double.valueOf(alpha));
        }
        this.debug.println(3, "kernel weights : " + this.listOfKernelWeights);
        this.debug.println(1, "PKL trained in " + (System.currentTimeMillis() - startMillis) + " milis.");
    }

    /**
     * Runs one outer iteration.
     *
     * <p>It computes the gradient at the current weights, then searches along the fixed
     * direction {@code -B * grad}. A trial that lowers the objective is accepted: the weights
     * move and the search continues from there with the same step factor. A trial that does not
     * lower the objective divides the step factor by ten. The search ends when the step factor
     * is no longer above {@link #num_cleaning}.
     *
     * @return the objective after this iteration divided by the objective before it
     */
    private double performPKLStep() {
        double bestObjective = this.oldObjective;
        double[] previousWeights = this.weights;

        // The lambda matrix may describe the last rejected trial of the previous iteration.
        // Rebuild it for the current weights before taking the gradient.
        GeneralizedDoubleGaussL2 currentKernel = new GeneralizedDoubleGaussL2(this.weights);
        updateLambdaMatrix(trainSVM(currentKernel).getAlphas(), currentKernel);
        double[] grad = computeGrad(currentKernel);
        double[] scaling = computeB(grad);

        double step = 1.0d;
        do {
            double[] candidate = new double[this.weights.length];
            double sum = 0.0d;
            for (int k = 0; k < candidate.length; k++) {
                candidate[k] = this.weights[k] - ((step * scaling[k]) * grad[k]);
                // Project onto the non-negative orthant and drop negligible weights.
                if (candidate[k] < this.num_cleaning) {
                    candidate[k] = 0.0d;
                }
                sum += candidate[k];
            }
            // Skip normalisation when every weight was cleaned: dividing by 0 would give NaNs.
            if (this.hasNorm && sum > 0.0d) {
                for (int k = 0; k < candidate.length; k++) {
                    candidate[k] = candidate[k] / sum;
                }
            }

            GeneralizedDoubleGaussL2 trialKernel = new GeneralizedDoubleGaussL2(candidate);
            double[] trialAlphas = trainSVM(trialKernel).getAlphas();
            updateLambdaMatrix(trialAlphas, trialKernel);
            double trialObjective = computeObj(trialAlphas);
            if (trialObjective >= bestObjective) {
                this.debug.print(3, ".");
                step /= 10.0d;
            } else {
                bestObjective = trialObjective;
                this.weights = candidate;
            }
            this.debug.println(3, "");
        } while (step > this.num_cleaning);

        // Remember this iteration's gradient and weight change for the next secant estimate.
        this.g = grad;
        if (this.diffWeights == null || this.diffWeights.length != this.weights.length) {
            this.diffWeights = new double[this.weights.length];
        }
        for (int k = 0; k < this.weights.length; k++) {
            this.diffWeights[k] = this.weights[k] - previousWeights[k];
        }
        this.debug.println(3, "++++++ w : " + Arrays.toString(this.weights));
        double ratio = bestObjective / this.oldObjective;
        this.oldObjective = bestObjective;
        return ratio;
    }

    /**
     * Computes the gradient vector, one entry per dimension, in parallel on the shared thread pool.
     *
     * <p>Entry {@code k} is {@code 0.5 * gradientComponent(kernel, k)}. Entries whose magnitude
     * is below {@link #num_cleaning} are set to zero. Must be called after
     * {@link #updateLambdaMatrix} for the same kernel.
     *
     * @param kernel kernel built from the current weights
     * @return the gradient, of length {@link #dim}
     * @throws IllegalStateException if a component computation fails or the thread is interrupted
     */
    private double[] computeGrad(final GeneralizedDoubleGaussL2 kernel) {
        this.debug.print(3, "++++++ g : ");
        final double[] grad = new double[this.dim];
        ThreadPoolExecutor executor = ThreadPoolServer.getThreadPoolExecutor();
        List<Future<?>> futures = new ArrayList<>(grad.length);
        try {
            for (int k = 0; k < grad.length; k++) {
                final int index = k;
                // Each task writes only its own slot. Future.get() below makes the writes
                // visible to this thread.
                futures.add(executor.submit(new Runnable() {
                    @Override
                    public void run() {
                        grad[index] = 0.5d * gradientComponent(kernel, index);
                    }
                }));
            }
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while computing the gradient", e);
        } catch (ExecutionException e) {
            // A missing component would silently corrupt the descent direction.
            throw new IllegalStateException("error with grad", e);
        } finally {
            ThreadPoolServer.shutdownNow(executor);
        }
        for (int k = 0; k < grad.length; k++) {
            if (Math.abs(grad[k]) < this.num_cleaning) {
                grad[k] = 0.0d;
            }
        }
        this.debug.println(3, Arrays.toString(grad));
        return grad;
    }

    /**
     * Returns {@code sum_ij D[i][j] * lambda_matrix[i][j]}, skipping rows whose lambda row is
     * {@code null}.
     *
     * <p>{@code D = kernel.distanceMatrixUnthreaded(listOfExamples, index)}; presumably the
     * distance matrix restricted to dimension {@code index} — confirm against
     * {@link GeneralizedDoubleGaussL2}.
     */
    private double gradientComponent(GeneralizedDoubleGaussL2 kernel, int index) {
        double[][] distances = kernel.distanceMatrixUnthreaded(this.listOfExamples, index);
        double sum = 0.0d;
        for (int i = 0; i < distances.length; i++) {
            double[] lambdaRow = this.lambda_matrix[i];
            if (lambdaRow == null) {
                continue;
            }
            double[] distanceRow = distances[i];
            for (int j = 0; j < distances.length; j++) {
                sum += distanceRow[j] * lambdaRow[j];
            }
        }
        return sum;
    }

    /**
     * Computes the per-coordinate step scaling {@code B}.
     *
     * <p>On the first step of a run, every entry is {@link #d_lambda}. Afterwards, entry
     * {@code k} is {@code diffWeights[k] / (grad[k] - g[k])}, clamped below at
     * {@link #num_cleaning}. If the gradient difference is zero, the entry becomes
     * {@link #num_cleaning}.
     *
     * @param grad gradient at the current weights
     * @return the (reused) {@link #B} array
     */
    private double[] computeB(double[] grad) {
        if (this.B == null || this.diffWeights == null) {
            this.B = new double[grad.length];
            Arrays.fill(this.B, this.d_lambda);
        } else {
            for (int k = 0; k < this.g.length; k++) {
                double gradDiff = grad[k] - this.g[k];
                this.debug.print(3, " gn-g : " + gradDiff + " wn-w : " + this.diffWeights[k]);
                double secant = gradDiff;
                if (gradDiff != 0.0d) {
                    secant = this.diffWeights[k] / gradDiff;
                }
                this.debug.println(3, " b : " + secant);
                this.B[k] = Math.max(this.num_cleaning, secant);
            }
        }
        this.debug.println(3, "++++++ B : " + Arrays.toString(this.B));
        return this.B;
    }

    /**
     * Returns {@code sum(alphas) - 0.5 * sum_ij lambda_matrix[i][j]}.
     *
     * <p>{@link #lambda_matrix} must have been built from the same {@code alphas}.
     */
    private double computeObj(double[] alphas) {
        double alphaSum = 0.0d;
        for (double alpha : alphas) {
            alphaSum += alpha;
        }
        double lambdaSum = 0.0d;
        for (double[] row : this.lambda_matrix) {
            if (row == null) {
                continue;
            }
            for (double value : row) {
                lambdaSum += value;
            }
        }
        double objective = alphaSum - (0.5d * lambdaSum);
        this.debug.println(3, "+++ obj : " + objective + "\t(obj1 : " + alphaSum + " obj3 : " + ((-0.5d) * lambdaSum) + ")");
        return objective;
    }

    /**
     * Rebuilds {@link #lambda_matrix} as {@code a_i * y_i * a_j * y_j * K_ij} for the given
     * alphas and kernel.
     *
     * <p>Rows with {@code a_i == 0} are stored as {@code null}. The matrix is reallocated
     * whenever its size differs from the current number of examples.
     *
     * @param alphas SVM alphas, one per entry of {@link #listOfExamples}
     * @param kernel kernel used to compute the Gram matrix of {@link #listOfExamples}
     */
    private void updateLambdaMatrix(final double[] alphas, GeneralizedDoubleGaussL2 kernel) {
        final double[][] kernelMatrix = kernel.getKernelMatrix(this.listOfExamples);
        final int size = kernelMatrix.length;
        if (this.lambda_matrix == null || this.lambda_matrix.length != size) {
            this.lambda_matrix = new double[size][size];
        }
        this.debug.println(3, "+ update lambda matrix");
        this.lambda_matrix = new ThreadedMatrixOperator() {
            @Override
            public void doLines(double[][] matrix, int from, int to) {
                for (int i = from; i < to; i++) {
                    if (alphas[i] == 0.0d) {
                        matrix[i] = null; // zero alpha: the whole row is zero, do not store it
                        continue;
                    }
                    if (matrix[i] == null) {
                        matrix[i] = new double[size];
                    }
                    double alphaLabelI = alphas[i] * DoubleQNPKL.this.listOfExamples.get(i).label;
                    double[] row = matrix[i];
                    for (int j = 0; j < row.length; j++) {
                        row[j] = alphaLabelI * DoubleQNPKL.this.listOfExamples.get(j).label * alphas[j] * kernelMatrix[i][j];
                    }
                }
            }
        }.getMatrix(this.lambda_matrix);
    }

    /** Trains a fresh {@link LaSVM} on {@link #listOfExamples} with the given kernel and {@link #C}. */
    private LaSVM<double[]> trainSVM(GeneralizedDoubleGaussL2 kernel) {
        LaSVM<double[]> laSVM = new LaSVM<>(kernel);
        laSVM.setC(this.C);
        laSVM.setE(10);
        this.debug.println(3, "+ training svm");
        laSVM.train(this.listOfExamples);
        return laSVM;
    }

    /** Returns the decision value of the trained SVM for {@code dArr}. */
    @Override
    public double valueOf(double[] dArr) {
        return this.svm.valueOf(dArr);
    }

    /** Returns the SVM regularization constant. */
    @Override
    public double getC() {
        return this.C;
    }

    /** Sets the SVM regularization constant used by subsequent training. */
    @Override
    public void setC(double d) {
        this.C = d;
    }

    /** Stores a p-norm value; it is not used by the training code in this class. */
    public void setPNorm(double d) {
        this.p_norm = d;
    }

    /** Sets the relative objective decrease below which training stops. */
    public void setStopGap(double d) {
        this.stopGap = d;
    }

    /** Returns the numerical cleaning threshold. */
    public double getNum_cleaning() {
        return this.num_cleaning;
    }

    /** Sets the numerical cleaning threshold. */
    public void setNum_cleaning(double d) {
        this.num_cleaning = d;
    }

    /** Returns the alphas of the final SVM (the internal list, not a copy). */
    public List<Double> getExampleWeights() {
        return this.listOfExampleWeights;
    }

    /** Returns the learned kernel weights (the internal list, not a copy). */
    public List<Double> getListOfKernelWeights() {
        return this.listOfKernelWeights;
    }

    /** Returns a fresh array copy of the learned kernel weights. */
    public double[] getKernelWeights() {
        double[] result = new double[this.listOfKernelWeights.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = this.listOfKernelWeights.get(k).doubleValue();
        }
        return result;
    }

    /** Returns the alphas of the trained SVM. */
    @Override
    public double[] getAlphas() {
        return this.svm.getAlphas();
    }

    /**
     * Ignored.
     *
     * <p>The kernel is always a {@link GeneralizedDoubleGaussL2} built from the learned weights.
     */
    @Override
    public void setKernel(Kernel<double[]> kernel) {
    }

    /** Returns the kernel of the trained SVM. */
    @Override
    public Kernel<double[]> getKernel() {
        return this.svm.getKernel();
    }

    /** Returns whether trial weight vectors are rescaled to sum to one. */
    public boolean isHasNorm() {
        return this.hasNorm;
    }

    /** Sets whether trial weight vectors are rescaled to sum to one. */
    public void setHasNorm(boolean z) {
        this.hasNorm = z;
    }

    /** Returns a shallow copy produced by {@link Object#clone()}. */
    @Override
    public DoubleQNPKL copy() throws CloneNotSupportedException {
        return (DoubleQNPKL) super.clone();
    }

    /** Returns the stored p-norm value. */
    public double getPNorm() {
        return this.p_norm;
    }

    /** Returns the relative objective decrease below which training stops. */
    public double getStopGap() {
        return this.stopGap;
    }
}
