package net.jkernelmachines.evaluation;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import net.jkernelmachines.classifier.Classifier;
import net.jkernelmachines.type.TrainingSample;

/* loaded from: input_file:net/jkernelmachines/evaluation/RandomSplitCrossValidation.class */
/**
 * Cross-validation by repeated random train/test splits.
 *
 * <p>Each call to {@link #run()} performs {@code nbTest} rounds. In every round the
 * samples are shuffled, a leading fraction {@code trainPercent} is used for training
 * and the remainder for testing, and the score reported by the evaluator is recorded.
 * When balancing is enabled, samples with label {@code 1} and all other samples are
 * split separately with the same fraction.
 *
 * @param <T> type of the data carried by the training samples
 */
public class RandomSplitCrossValidation<T> implements CrossValidation, BalancedCrossValidation {
    Classifier<T> classifier;
    Evaluator<T> evaluator;
    /** Scores of the last {@link #run()}, one per round; {@code null} until a run has started. */
    double[] results;
    boolean balance = true;
    long seed = 0;
    double trainPercent = 0.7d;
    int nbTest = 20;
    List<TrainingSample<T>> list = new ArrayList<>();

    /**
     * Creates a cross-validation over a copy of the given samples.
     *
     * @param classifier the classifier handed to the evaluator in each round
     * @param list the samples to split; its elements are copied into an internal list
     * @param evaluator the evaluator that is run on each split and provides the score
     */
    public RandomSplitCrossValidation(Classifier<T> classifier, List<TrainingSample<T>> list, Evaluator<T> evaluator) {
        this.classifier = classifier;
        this.list.addAll(list);
        this.evaluator = evaluator;
    }

    /**
     * Runs {@code nbTest} rounds of shuffle / split / evaluate and stores one score per round.
     *
     * <p>The shuffling is driven by a {@link Random} seeded with {@code seed} and is applied
     * to a private working copy, so the list returned by {@link #getList()} is never
     * reordered and repeated runs with the same seed produce the same splits.
     *
     * <p>{@code trainPercent} is not validated here; values outside {@code [0, 1]} may make
     * the split fail with a runtime exception.
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public void run() {
        this.results = new double[this.nbTest];
        Random random = new Random(this.seed);
        // Shuffle a working copy: the stored list may be the caller's own (see setList).
        List<TrainingSample<T>> shuffled = new ArrayList<>(this.list);

        for (int round = 0; round < this.nbTest; round++) {
            // Shuffles accumulate from round to round, all drawn from the same Random.
            Collections.shuffle(shuffled, random);

            List<TrainingSample<T>> train = new ArrayList<>(shuffled.size());
            List<TrainingSample<T>> test = new ArrayList<>(shuffled.size());

            if (this.balance) {
                List<TrainingSample<T>> positives = new ArrayList<>();
                List<TrainingSample<T>> others = new ArrayList<>();
                for (TrainingSample<T> sample : shuffled) {
                    if (sample.label == 1) {
                        positives.add(sample);
                    } else {
                        others.add(sample);
                    }
                }
                // Split each group on its own so both sets keep the same label proportions.
                split(positives, train, test);
                split(others, train, test);
            } else {
                split(shuffled, train, test);
            }

            this.evaluator.setClassifier(this.classifier);
            this.evaluator.setTrainingSet(train);
            this.evaluator.setTestingSet(test);
            this.evaluator.evaluate();
            this.results[round] = this.evaluator.getScore();
        }
    }

    /**
     * Appends the first {@code (int) (source.size() * trainPercent)} elements of
     * {@code source} to {@code train} and the remaining elements to {@code test}.
     *
     * <p>The split is positional, so every element of {@code source} ends up in exactly
     * one of the two lists, independently of how the samples implement {@code equals}.
     */
    private void split(List<TrainingSample<T>> source, List<TrainingSample<T>> train,
            List<TrainingSample<T>> test) {
        int cut = (int) (source.size() * this.trainPercent);
        train.addAll(source.subList(0, cut));
        test.addAll(source.subList(cut, source.size()));
    }

    /**
     * Returns the arithmetic mean of the recorded scores.
     *
     * @return the mean score, or {@code NaN} if {@link #run()} has not been called
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double getAverageScore() {
        if (this.results == null) {
            return Double.NaN;
        }
        double sum = 0.0d;
        for (double score : this.results) {
            sum += score;
        }
        return sum / this.results.length;
    }

    /**
     * Returns the standard deviation of the recorded scores, computed with a divisor of
     * the number of scores (not that number minus one).
     *
     * @return the standard deviation, or {@code NaN} if {@link #run()} has not been called
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double getStdDevScore() {
        if (this.results == null) {
            return Double.NaN;
        }
        double mean = getAverageScore();
        double squaredDeviations = 0.0d;
        for (double score : this.results) {
            double deviation = score - mean;
            squaredDeviations += deviation * deviation;
        }
        return Math.sqrt(squaredDeviations / this.results.length);
    }

    /**
     * Returns the internal array of per-round scores (not a copy).
     *
     * @return the scores of the last run, or {@code null} if {@link #run()} has not been called
     */
    @Override // net.jkernelmachines.evaluation.CrossValidation
    public double[] getScores() {
        return this.results;
    }

    /** Returns the classifier handed to the evaluator in each round. */
    public Classifier<T> getClassifier() {
        return this.classifier;
    }

    /** Sets the classifier handed to the evaluator in each round. */
    public void setClassifier(Classifier<T> classifier) {
        this.classifier = classifier;
    }

    /** Returns the list of samples that {@link #run()} splits. */
    public List<TrainingSample<T>> getList() {
        return this.list;
    }

    /**
     * Sets the list of samples that {@link #run()} splits. The reference is stored as is,
     * without copying.
     */
    public void setList(List<TrainingSample<T>> list) {
        this.list = list;
    }

    /** Returns the fraction of samples used for training in each round. */
    public double getTrainPercent() {
        return this.trainPercent;
    }

    /** Sets the fraction of samples used for training in each round (default {@code 0.7}). */
    public void setTrainPercent(double d) {
        this.trainPercent = d;
    }

    /** Returns the number of rounds performed by {@link #run()}. */
    public int getNbTest() {
        return this.nbTest;
    }

    /** Sets the number of rounds performed by {@link #run()} (default {@code 20}). */
    public void setNbTest(int i) {
        this.nbTest = i;
    }

    /** Returns the seed of the random generator used for shuffling. */
    public long getSeed() {
        return this.seed;
    }

    /** Sets the seed of the random generator used for shuffling (default {@code 0}). */
    public void setSeed(long j) {
        this.seed = j;
    }

    /** Returns whether label-{@code 1} samples and the other samples are split separately. */
    @Override // net.jkernelmachines.evaluation.BalancedCrossValidation
    public boolean isBalanced() {
        return this.balance;
    }

    /** Enables or disables splitting label-{@code 1} samples and the other samples separately. */
    @Override // net.jkernelmachines.evaluation.BalancedCrossValidation
    public void setBalanced(boolean z) {
        this.balance = z;
    }
}
