package net.jkernelmachines.evaluation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.jkernelmachines.classifier.Classifier;
import net.jkernelmachines.type.TrainingSample;
import net.jkernelmachines.util.ArraysUtils;
import net.jkernelmachines.util.DebugPrinter;

/* loaded from: input_file:net/jkernelmachines/evaluation/NFoldCrossValidation.class */
/**
 * N-fold cross-validation of a {@link Classifier} over a fixed list of samples.
 *
 * <p>The sample list is split into {@code N} contiguous folds in list order (no shuffling is
 * done here). For each fold the classifier is trained on all remaining samples and every
 * registered {@link Evaluator} is scored on the held-out fold. Scores are stored per evaluator
 * name; the evaluator passed to the constructor is registered under {@code "default"}, which is
 * the one reported by the no-argument score accessors.
 *
 * <p>In balanced mode (the default) samples with {@code label == 1} and all other samples are
 * folded separately and then combined, so every test fold receives its share of both groups.
 *
 * @param <T> the sample type handled by the classifier
 */
public class NFoldCrossValidation<T> implements CrossValidation, BalancedCrossValidation, MultipleEvaluatorCrossValidation<T> {
    /** Number of folds; the constructor clamps it to at least 2. */
    int N;
    /** Classifier retrained on each training split. */
    Classifier<T> classifier;
    /** Private copy of the samples to cross-validate. */
    List<TrainingSample<T>> list;
    /** Whether positives ({@code label == 1}) and the rest are folded separately. */
    boolean balanced = true;
    /** Registered evaluators, keyed by name. */
    Map<String, Evaluator<T>> evaluators = new HashMap<String, Evaluator<T>>();
    /** Per-evaluator fold scores produced by the last {@link #run()}; one entry per fold. */
    Map<String, double[]> results = new HashMap<String, double[]>();
    DebugPrinter debug = new DebugPrinter();

    /**
     * Creates a cross-validation run.
     *
     * @param i          requested number of folds; values below 2 are raised to 2
     * @param classifier classifier to train on each split
     * @param list       samples to evaluate; copied, so later changes to it are not seen
     * @param evaluator  evaluator registered under the name {@code "default"}
     */
    public NFoldCrossValidation(int i, Classifier<T> classifier, List<TrainingSample<T>> list, Evaluator<T> evaluator) {
        this.N = Math.max(i, 2);
        this.classifier = classifier;
        this.evaluators.put("default", evaluator);
        this.list = new ArrayList<TrainingSample<T>>(list);
    }

    /**
     * Runs all {@code N} folds, filling one score per fold for every registered evaluator.
     *
     * <p>Every sample appears in exactly one test fold; fold sizes differ by at most one sample.
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public void run() {
        for (String name : this.evaluators.keySet()) {
            this.results.put(name, new double[this.N]);
        }

        // Group samples by label so balanced folds draw from each group independently.
        List<TrainingSample<T>> positives = new ArrayList<TrainingSample<T>>();
        List<TrainingSample<T>> negatives = new ArrayList<TrainingSample<T>>();
        for (TrainingSample<T> sample : this.list) {
            if (sample.label == 1) {
                positives.add(sample);
            } else {
                negatives.add(sample);
            }
        }

        for (int fold = 0; fold < this.N; fold++) {
            List<TrainingSample<T>> test = new ArrayList<TrainingSample<T>>();
            List<TrainingSample<T>> train = new ArrayList<TrainingSample<T>>();
            if (this.balanced) {
                // Collect the train/test parts of each group separately, then concatenate
                // them (positives first) to keep the same ordering as before.
                List<TrainingSample<T>> negTest = new ArrayList<TrainingSample<T>>();
                List<TrainingSample<T>> negTrain = new ArrayList<TrainingSample<T>>();
                splitFold(positives, fold, test, train);
                splitFold(negatives, fold, negTest, negTrain);
                test.addAll(negTest);
                train.addAll(negTrain);
            } else {
                splitFold(this.list, fold, test, train);
            }
            this.debug.println(4, "train size: " + train.size());
            this.debug.println(4, "test size: " + test.size());
            this.classifier.train(train);
            for (Map.Entry<String, Evaluator<T>> entry : this.evaluators.entrySet()) {
                Evaluator<T> evaluator = entry.getValue();
                evaluator.setClassifier(this.classifier);
                // The classifier was trained just above; no training set is handed over.
                evaluator.setTrainingSet(null);
                evaluator.setTestingSet(test);
                evaluator.evaluate();
                this.results.get(entry.getKey())[fold] = evaluator.getScore();
            }
        }
    }

    /**
     * Splits {@code source} into {@code N} contiguous folds, appending fold {@code fold} to
     * {@code test} and every other sample to {@code train}, both in source order.
     *
     * <p>Boundaries are {@code floor(k * size / N)}, so the folds cover every sample exactly
     * once even when {@code N} does not divide {@code size} (or exceeds it). Samples are split
     * by position, not by {@code equals}, so duplicate samples are handled correctly.
     */
    private void splitFold(List<TrainingSample<T>> source, int fold,
            List<TrainingSample<T>> test, List<TrainingSample<T>> train) {
        int size = source.size();
        // long arithmetic avoids int overflow of fold * size on very large lists
        int from = (int) ((long) fold * size / this.N);
        int to = (int) ((long) (fold + 1) * size / this.N);
        test.addAll(source.subList(from, to));
        train.addAll(source.subList(0, from));
        train.addAll(source.subList(to, size));
    }

    /**
     * Returns the mean fold score of the {@code "default"} evaluator.
     *
     * @return the mean score, or {@code NaN} if no results exist for {@code "default"}
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double getAverageScore() {
        return getAverageScore("default");
    }

    /**
     * Returns the standard deviation of the {@code "default"} evaluator's fold scores.
     *
     * @return the standard deviation, or {@code NaN} if no results exist for {@code "default"}
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double getStdDevScore() {
        return getStdDevScore("default");
    }

    /**
     * Returns the per-fold scores of the {@code "default"} evaluator.
     *
     * @return the internal score array, or {@code null} if no results exist
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double[] getScores() {
        return getScores("default");
    }

    @Override // net.jkernelmachines.evaluation.BalancedCrossValidation
    public boolean isBalanced() {
        return this.balanced;
    }

    @Override // net.jkernelmachines.evaluation.BalancedCrossValidation
    public void setBalanced(boolean z) {
        this.balanced = z;
    }

    /** Registers (or replaces) an evaluator under {@code str}; it is scored on the next run. */
    @Override // net.jkernelmachines.evaluation.MultipleEvaluatorCrossValidation
    public void addEvaluator(String str, Evaluator<T> evaluator) {
        this.evaluators.put(str, evaluator);
    }

    /** Unregisters the evaluator named {@code str}; unknown names are ignored. */
    @Override // net.jkernelmachines.evaluation.MultipleEvaluatorCrossValidation
    public void removeEvaluator(String str) {
        this.evaluators.remove(str);
    }

    /**
     * Returns the mean fold score of the named evaluator.
     *
     * @return the mean score, or {@code NaN} if no results exist for {@code str}
     */
    @Override // net.jkernelmachines.evaluation.MultipleEvaluatorCrossValidation
    public double getAverageScore(String str) {
        double[] scores = this.results.get(str);
        return scores == null ? Double.NaN : ArraysUtils.mean(scores);
    }

    /**
     * Returns the standard deviation of the named evaluator's fold scores.
     *
     * @return the standard deviation, or {@code NaN} if no results exist for {@code str}
     */
    @Override // net.jkernelmachines.evaluation.MultipleEvaluatorCrossValidation
    public double getStdDevScore(String str) {
        double[] scores = this.results.get(str);
        return scores == null ? Double.NaN : ArraysUtils.stddev(scores);
    }

    /**
     * Returns the per-fold scores of the named evaluator.
     *
     * @return the internal score array, or {@code null} if no results exist for {@code str}
     */
    @Override // net.jkernelmachines.evaluation.MultipleEvaluatorCrossValidation
    public double[] getScores(String str) {
        return this.results.get(str);
    }
}
