package tech.tablesaw.api.ml.classification;

import com.google.common.base.Preconditions;
import java.util.Collection;
import java.util.TreeSet;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.util.DoubleArrays;

/**
 * Random-forest classifier that wraps {@code smile.classification.RandomForest}.
 *
 * <p>Instances are created only through the static {@code learn} factories. Each factory trains a
 * model from a column of class labels and one or more numeric feature columns. Feature values are
 * passed to Smile as a row-major {@code double[][]} built by {@link DoubleArrays#to2dArray}.
 */
public class RandomForest extends AbstractClassifier {

    /** The trained Smile model; every prediction delegates to it. */
    private final smile.classification.RandomForest classifierModel;

    /**
     * Trains the underlying Smile model.
     *
     * @param ntrees number of trees to grow; must be positive
     * @param labels class label for each observation
     * @param features feature columns; at least one is required
     * @throws IllegalArgumentException if {@code ntrees} is not positive or no feature columns are
     *     supplied
     */
    private RandomForest(int ntrees, int[] labels, NumericColumn... features) {
        Preconditions.checkArgument(ntrees > 0, "Number of trees must be positive, got %s", ntrees);
        checkHasFeatures(features);
        this.classifierModel =
                new smile.classification.RandomForest(DoubleArrays.to2dArray(features), labels, ntrees);
    }

    /**
     * Trains a random forest using an {@link IntColumn} of class labels.
     *
     * @param ntrees number of trees to grow; must be positive
     * @param labels class label for each observation
     * @param features feature columns; at least one is required
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code ntrees} is not positive or no feature columns are
     *     supplied
     */
    public static RandomForest learn(int ntrees, IntColumn labels, NumericColumn... features) {
        return new RandomForest(ntrees, labels.data().toIntArray(), features);
    }

    /**
     * Trains a random forest using a {@link ShortColumn} of class labels.
     *
     * @param ntrees number of trees to grow; must be positive
     * @param labels class label for each observation, widened to {@code int}
     * @param features feature columns; at least one is required
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code ntrees} is not positive or no feature columns are
     *     supplied
     */
    public static RandomForest learn(int ntrees, ShortColumn labels, NumericColumn... features) {
        return new RandomForest(ntrees, labels.toIntArray(), features);
    }

    /**
     * Trains a random forest using a {@link CategoryColumn} of class labels.
     *
     * <p>The model is trained on the column's integer encoding, as returned by
     * {@code labels.data().toIntArray()}, not on the category strings themselves.
     *
     * @param ntrees number of trees to grow; must be positive
     * @param labels categorical class label for each observation
     * @param features feature columns; at least one is required
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code ntrees} is not positive or no feature columns are
     *     supplied
     */
    public static RandomForest learn(int ntrees, CategoryColumn labels, NumericColumn... features) {
        return new RandomForest(ntrees, labels.data().toIntArray(), features);
    }

    /**
     * Predicts the class of a single observation.
     *
     * @param features feature values for one observation, in the same order as the feature columns
     *     used for training
     * @return the predicted class label, in the integer encoding used for training
     */
    public int predict(double[] features) {
        // Single call path into the model, shared with the confusion-matrix code.
        return predictFromModel(features);
    }

    /**
     * Predicts every row and tabulates the predictions against the given {@link ShortColumn}
     * labels.
     *
     * @param labels actual class labels, one per row
     * @param features feature columns to predict from; at least one is required
     * @return a confusion matrix over the sorted set of distinct label values
     * @throws IllegalArgumentException if no feature columns are supplied
     */
    public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn... features) {
        checkHasFeatures(features);
        // TreeSet gives the matrix a deterministic, sorted label order.
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(new TreeSet<>(labels.asSet()));
        populateMatrix(labels.toIntArray(), confusion, features);
        return confusion;
    }

    /**
     * Predicts every row and tabulates the predictions against the given {@link CategoryColumn}
     * labels.
     *
     * @param labels actual categorical labels, one per row
     * @param features feature columns to predict from; at least one is required
     * @return a confusion matrix over the sorted set of distinct category values
     * @throws IllegalArgumentException if no feature columns are supplied
     */
    public ConfusionMatrix predictMatrix(CategoryColumn labels, NumericColumn... features) {
        checkHasFeatures(features);
        // TreeSet gives the matrix a deterministic, sorted label order.
        CategoryConfusionMatrix confusion =
                new CategoryConfusionMatrix(labels, new TreeSet<>(labels.asSet()));
        populateMatrix(labels.data().toIntArray(), confusion, features);
        return confusion;
    }

    /**
     * Runs the trained Smile model on one observation. This is the hook used by
     * {@link AbstractClassifier}.
     *
     * @param features feature values for one observation
     * @return the predicted class label
     */
    @Override
    int predictFromModel(double[] features) {
        return this.classifierModel.predict(features);
    }

    /** Rejects an empty feature list with a descriptive message. */
    private static void checkHasFeatures(NumericColumn[] features) {
        Preconditions.checkArgument(features.length > 0, "At least one feature column is required");
    }
}
