package tech.tablesaw.api.ml.classification;

import com.google.common.base.Preconditions;
import java.util.Collection;
import java.util.TreeSet;
import smile.classification.KNN;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.IntConvertibleColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.util.DoubleArrays;

/* loaded from: input_file:tech/tablesaw/api/ml/classification/Knn.class */
/**
 * A k-nearest-neighbour classifier that adapts Smile's {@link KNN} model to
 * tablesaw columns.
 *
 * <p>Instances are created through {@link #learn(int, IntConvertibleColumn, NumericColumn...)};
 * the constructor is private.
 */
public class Knn extends AbstractClassifier {
    /** The trained Smile model that every prediction is delegated to. */
    private final KNN<double[]> classifierModel;

    private Knn(KNN<double[]> classifierModel) {
        this.classifierModel = classifierModel;
    }

    /**
     * Trains a new classifier.
     *
     * @param k          forwarded to {@code KNN.learn} as its third argument (the number of
     *                   neighbours in Smile's API)
     * @param labels     column supplying the integer class label of each training row
     * @param predictors feature columns, converted to a 2-d array via
     *                   {@code DoubleArrays.to2dArray}
     * @return a classifier wrapping the trained model
     */
    public static Knn learn(int k, IntConvertibleColumn labels, NumericColumn... predictors) {
        double[][] features = DoubleArrays.to2dArray(predictors);
        int[] classes = labels.asIntArray();
        return new Knn(KNN.learn(features, classes, k));
    }

    /**
     * Predicts the class of a single observation.
     *
     * @param data feature values of the observation
     * @return the label returned by the underlying model
     */
    public int predict(double[] data) {
        return this.classifierModel.predict(data);
    }

    /**
     * Builds a confusion matrix by comparing the given labels against the model's predictions
     * for the given predictor columns.
     *
     * @param labels     column supplying the actual class labels
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(IntConvertibleColumn labels, NumericColumn... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        // TreeSet keeps the distinct labels in sorted order for the matrix.
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(new TreeSet(labels.asIntegerSet()));
        populateMatrix(labels.asIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Builds a confusion matrix by comparing the given labels against the model's predictions
     * for the given predictor columns.
     *
     * @param labels     column supplying the actual class labels
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(IntColumn labels, NumericColumn... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        // TreeSet keeps the distinct labels in sorted order for the matrix.
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(new TreeSet((Collection) labels.asSet()));
        populateMatrix(labels.data().toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Predicts a class for every row of the given predictor columns.
     *
     * <p>The number of rows is taken from the first column.
     *
     * @param predictors feature columns; at least one is required
     * @return one predicted label per row, in row order
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public int[] predict(NumericColumn... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        int rowCount = predictors[0].size();
        int[] predictedValues = new int[rowCount];
        for (int row = 0; row < rowCount; row++) {
            double[] features = new double[predictors.length];
            // Each column fills its own slot of the feature vector. Indexing by the row
            // number here would overwrite a single slot and overflow the array as soon as
            // there are more rows than columns.
            for (int col = 0; col < predictors.length; col++) {
                features[col] = predictors[col].getFloat(row);
            }
            predictedValues[row] = this.classifierModel.predict(features);
        }
        return predictedValues;
    }

    @Override // tech.tablesaw.api.ml.classification.AbstractClassifier
    int predictFromModel(double[] data) {
        return this.classifierModel.predict(data);
    }
}
