package net.sf.tweety.machinelearning;

import java.util.Iterator;
import net.sf.tweety.machinelearning.Category;
import net.sf.tweety.machinelearning.Observation;

/* loaded from: input_file:net.sf.tweety.machinelearning-1.9.jar:net/sf/tweety/machinelearning/GridSearchParameterLearner.class */
/**
 * Learns the parameters of a trainer by an iteratively refined grid search.
 *
 * <p>At every level, the range {@code [lower, upper]} of each parameter is split into
 * {@code partitions} equal steps, so each parameter takes {@code partitions + 1} grid values, and
 * every combination of these values is scored with the given {@link ClassificationTester}.
 * A higher score is considered better. Between levels, each parameter's range is narrowed to the
 * grid cells adjacent to the best value found on the current level.
 *
 * @param <S> the type of observations
 * @param <T> the type of categories
 */
public class GridSearchParameterLearner<S extends Observation, T extends Category> extends ParameterTrainer<S, T> {
    /** Number of refinement levels. */
    private final int depth;
    /** Number of equal steps each parameter range is split into on every level. */
    private final int partitions;
    /** Scores a parameterised trainer on a training set; higher is better. */
    private final ClassificationTester<S, T> tester;

    /**
     * Creates a new grid search parameter learner.
     *
     * @param trainer the trainer whose parameters are learned
     * @param classificationTester scores a parameterised trainer on a training set (higher is better)
     * @param depth the number of refinement levels
     * @param partitions the number of equal steps each parameter range is split into; at least 1
     * @throws IllegalArgumentException if {@code partitions < 1}
     */
    public GridSearchParameterLearner(Trainer<S, T> trainer, ClassificationTester<S, T> classificationTester, int depth, int partitions) {
        super(trainer);
        if (partitions < 1) {
            // a zero step count would divide by zero and yield NaN parameter values
            throw new IllegalArgumentException("partitions must be at least 1, was " + partitions);
        }
        this.tester = classificationTester;
        this.depth = depth;
        this.partitions = partitions;
    }

    /**
     * Searches the parameter space of the trainer and returns the best parameter set found.
     *
     * <p>Side effect: every evaluated candidate is installed on the trainer via
     * {@code setParameterSet}; afterwards the trainer holds the last evaluated candidate,
     * which is not necessarily the best one.
     *
     * @param trainingSet the training set handed to the tester for every candidate
     * @return the candidate with the highest tester score over all levels; if no candidate was
     *     evaluated (e.g. {@code depth <= 0}), every parameter is set to its lower bound
     */
    @Override // net.sf.tweety.machinelearning.ParameterTrainer
    public ParameterSet learnParameters(TrainingSet<S, T> trainingSet) {
        Trainer<S, T> trainer = getTrainer();
        ParameterSet parameterSet = trainer.getParameterSet();
        int n = parameterSet.size();

        // current search bounds per parameter, in the iteration order of parameterSet
        double[] lower = new double[n];
        double[] upper = new double[n];
        int k = 0;
        Iterator<TrainingParameter> it = parameterSet.iterator();
        while (it.hasNext()) {
            TrainingParameter parameter = it.next();
            lower[k] = parameter.getLowerBound();
            upper[k] = parameter.getUpperBound();
            k++;
        }

        // Overall best candidate, stored as grid index together with the bounds of its level so
        // that it can be rebuilt after later levels have narrowed the bounds.
        double bestScore = Double.NEGATIVE_INFINITY;
        int[] bestIndex = new int[n];
        double[] bestLower = lower.clone();
        double[] bestUpper = upper.clone();

        int[] current = new int[n];
        for (int level = 0; level < this.depth; level++) {
            // best grid index of this level only; it is relative to this level's bounds
            int[] levelBest = new int[n];
            double levelBestScore = Double.NEGATIVE_INFINITY;
            // visit every grid point; increment() leaves current all zeros once it wraps
            do {
                trainer.setParameterSet(adjustParameterSet(parameterSet, current, lower, upper));
                double score = this.tester.test(trainer, trainingSet);
                if (score > levelBestScore) {
                    levelBestScore = score;
                    System.arraycopy(current, 0, levelBest, 0, n);
                }
                if (score > bestScore) {
                    bestScore = score;
                    System.arraycopy(current, 0, bestIndex, 0, n);
                    System.arraycopy(lower, 0, bestLower, 0, n);
                    System.arraycopy(upper, 0, bestUpper, 0, n);
                }
            } while (!increment(current, this.partitions));

            if (level + 1 < this.depth) {
                // narrow each range to the cells adjacent to this level's best grid value;
                // step and offset must both be computed from the old bounds
                for (int j = 0; j < n; j++) {
                    double step = (upper[j] - lower[j]) / this.partitions;
                    double oldLower = lower[j];
                    lower[j] = oldLower + step * Math.max(levelBest[j] - 1, 0);
                    upper[j] = oldLower + step * Math.min(levelBest[j] + 1, this.partitions);
                }
            }
        }
        return adjustParameterSet(parameterSet, bestIndex, bestLower, bestUpper);
    }

    /**
     * Builds a new parameter set in which every parameter of {@code parameterSet} is instantiated
     * at its grid value {@code lower[i] + (upper[i] - lower[i]) / partitions * index[i]}.
     *
     * @param parameterSet the parameters to instantiate, iterated in the same order as the arrays
     * @param index the grid index of each parameter
     * @param lower the lower bound of each parameter's current range
     * @param upper the upper bound of each parameter's current range
     * @return a fresh parameter set holding the instantiated parameters
     */
    private ParameterSet adjustParameterSet(ParameterSet parameterSet, int[] index, double[] lower, double[] upper) {
        ParameterSet result = new ParameterSet();
        int i = 0;
        Iterator<TrainingParameter> it = parameterSet.iterator();
        while (it.hasNext()) {
            double step = (upper[i] - lower[i]) / this.partitions;
            result.add(it.next().instantiate(lower[i] + step * index[i]));
            i++;
        }
        return result;
    }

    /**
     * Advances {@code index} to the next grid point. The array is treated as a counter whose
     * lowest digit is at position 0 and each digit ranges over {@code 0..max} inclusive.
     *
     * @param index the counter, advanced in place
     * @param max the largest value a single digit may take
     * @return {@code true} if the counter wrapped around (it is all zeros again), meaning every
     *     grid point has been visited; {@code false} otherwise
     */
    private boolean increment(int[] index, int max) {
        for (int j = 0; j < index.length; j++) {
            if (index[j] < max) {
                index[j]++;
                return false;
            }
            index[j] = 0;
        }
        // also reached for an empty counter, whose grid has exactly one (empty) point;
        // returning true here prevents an endless loop in learnParameters
        return true;
    }
}
