package tech.tablesaw.api.ml.classification;

import com.google.common.base.Preconditions;
import java.util.TreeSet;
import tech.tablesaw.api.IntConvertibleColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.util.DoubleArrays;

/* loaded from: input_file:tech/tablesaw/api/ml/classification/LogisticRegression.class */
/**
 * Logistic-regression classifier that delegates training and prediction to
 * {@code smile.classification.LogisticRegression}.
 *
 * <p>Instances are created through the static {@code learn} factories, which train the wrapped
 * Smile model from Tablesaw columns.
 */
public class LogisticRegression extends AbstractClassifier {
    /** The trained Smile model every prediction is delegated to. */
    private final smile.classification.LogisticRegression classifierModel;

    private LogisticRegression(smile.classification.LogisticRegression logisticRegression) {
        this.classifierModel = logisticRegression;
    }

    /**
     * Trains a model using Smile's default hyper-parameters.
     *
     * @param intConvertibleColumn the class labels, converted with {@code asIntArray()}
     * @param numericColumnArr the predictor columns, converted with {@code DoubleArrays.to2dArray}
     * @return a classifier wrapping the trained Smile model
     */
    public static LogisticRegression learn(IntConvertibleColumn intConvertibleColumn, NumericColumn... numericColumnArr) {
        return new LogisticRegression(new smile.classification.LogisticRegression(DoubleArrays.to2dArray(numericColumnArr), intConvertibleColumn.asIntArray()));
    }

    /**
     * Trains a model with one extra hyper-parameter passed straight to Smile.
     *
     * @param intConvertibleColumn the class labels, converted with {@code asIntArray()}
     * @param d forwarded as Smile's third constructor argument (presumably the regularization
     *     factor lambda -- confirm against the Smile version in use)
     * @param numericColumnArr the predictor columns, converted with {@code DoubleArrays.to2dArray}
     * @return a classifier wrapping the trained Smile model
     */
    public static LogisticRegression learn(IntConvertibleColumn intConvertibleColumn, double d, NumericColumn... numericColumnArr) {
        return new LogisticRegression(new smile.classification.LogisticRegression(DoubleArrays.to2dArray(numericColumnArr), intConvertibleColumn.asIntArray(), d));
    }

    /**
     * Trains a model with explicit hyper-parameters passed straight to Smile.
     *
     * @param intConvertibleColumn the class labels, converted with {@code asIntArray()}
     * @param d forwarded as Smile's third constructor argument (presumably lambda -- confirm)
     * @param d2 forwarded as Smile's fourth constructor argument (presumably the convergence
     *     tolerance -- confirm)
     * @param i forwarded as Smile's fifth constructor argument (presumably the maximum number of
     *     iterations -- confirm)
     * @param numericColumnArr the predictor columns, converted with {@code DoubleArrays.to2dArray}
     * @return a classifier wrapping the trained Smile model
     */
    public static LogisticRegression learn(IntConvertibleColumn intConvertibleColumn, double d, double d2, int i, NumericColumn... numericColumnArr) {
        return new LogisticRegression(new smile.classification.LogisticRegression(DoubleArrays.to2dArray(numericColumnArr), intConvertibleColumn.asIntArray(), d, d2, i));
    }

    /**
     * Predicts the class label of a single observation.
     *
     * @param dArr the feature values of one observation
     * @return the label returned by the Smile model
     */
    public int predict(double[] dArr) {
        return this.classifierModel.predict(dArr);
    }

    /**
     * Builds a confusion matrix comparing the given actual labels with this model's predictions.
     *
     * <p>The matrix's label set is the distinct values of {@code intConvertibleColumn}, in sorted
     * order. Filling the matrix is delegated to the inherited {@code populateMatrix}.
     *
     * @param intConvertibleColumn the actual labels
     * @param numericColumnArr the predictor columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor columns are given
     */
    public ConfusionMatrix predictMatrix(IntConvertibleColumn intConvertibleColumn, NumericColumn... numericColumnArr) {
        Preconditions.checkArgument(numericColumnArr.length > 0);
        StandardConfusionMatrix standardConfusionMatrix = new StandardConfusionMatrix(new TreeSet(intConvertibleColumn.asIntegerSet()));
        populateMatrix(intConvertibleColumn.asIntArray(), standardConfusionMatrix, numericColumnArr);
        return standardConfusionMatrix;
    }

    /**
     * Predicts a label for every row of the given predictor columns.
     *
     * <p>The number of rows is taken from the first column.
     *
     * @param numericColumnArr the predictor columns, one per feature; at least one is required
     * @return one predicted label per row
     * @throws IllegalArgumentException if no predictor columns are given
     */
    public int[] predict(NumericColumn... numericColumnArr) {
        Preconditions.checkArgument(numericColumnArr.length > 0);
        int rowCount = numericColumnArr[0].size();
        int[] predictions = new int[rowCount];
        for (int row = 0; row < rowCount; row++) {
            predictions[row] = this.classifierModel.predict(rowFeatures(row, numericColumnArr));
        }
        return predictions;
    }

    @Override // tech.tablesaw.api.ml.classification.AbstractClassifier
    int predictFromModel(double[] dArr) {
        return this.classifierModel.predict(dArr);
    }

    /**
     * Returns the log-likelihood reported by the underlying Smile model.
     *
     * @return the value of Smile's {@code loglikelihood()}
     */
    public double logLikelihood() {
        return this.classifierModel.loglikelihood();
    }

    /**
     * Delegates to Smile's two-argument {@code predict(x, posteriori)}.
     *
     * @param dArr the feature values of one observation
     * @param dArr2 output array passed to Smile (presumably filled with per-class posterior
     *     probabilities -- confirm against the Smile version in use)
     * @return the value returned by Smile, widened to {@code double}
     */
    public double predictFromModel(double[] dArr, double[] dArr2) {
        return this.classifierModel.predict(dArr, dArr2);
    }

    /**
     * Predicts a single row of the given predictor columns, passing {@code dArr} to Smile as the
     * second (output) argument of {@code predict}.
     *
     * @param i the row index to predict
     * @param dArr output array passed to Smile (presumably posterior probabilities -- confirm)
     * @param numericColumnArr the predictor columns, one per feature; at least one is required
     * @return the value returned by Smile, widened to {@code double}
     * @throws IllegalArgumentException if no predictor columns are given
     */
    public double predictFromModel(int i, double[] dArr, NumericColumn... numericColumnArr) {
        Preconditions.checkArgument(numericColumnArr.length > 0);
        return this.classifierModel.predict(rowFeatures(i, numericColumnArr), dArr);
    }

    /**
     * Builds the feature vector for one row. Element {@code col} holds the value of
     * {@code columns[col]} at {@code row}.
     */
    private static double[] rowFeatures(int row, NumericColumn[] columns) {
        double[] features = new double[columns.length];
        for (int col = 0; col < columns.length; col++) {
            // Index by column, not by row: each column contributes exactly one feature.
            features[col] = columns[col].getFloat(row);
        }
        return features;
    }
}
