package tech.tablesaw.conversion.smile;

import java.io.IOException;
import java.util.Arrays;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import smile.classification.RandomForest;
import smile.data.AttributeDataset;
import smile.regression.OLS;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;

/**
 * Tests for the Smile conversion exposed through {@code Table.smile()}: building numeric and
 * nominal {@link AttributeDataset}s from a Tablesaw table and handing them to Smile models.
 *
 * <p>Every test reads the baseball fixture at {@value #BASEBALL_CSV}. That path is relative,
 * so the tests must be run from a working directory where it resolves.
 */
public class SmileConverterTest {

    /** Location of the baseball fixture shared by all tests (relative to the working directory). */
    private static final String BASEBALL_CSV = "../data/baseball.csv";

    /**
     * Reads the baseball fixture into a new table.
     *
     * @return a freshly loaded table; each call re-reads the file, so tests can mutate it freely
     * @throws IOException if the fixture cannot be read
     */
    private static Table loadBaseball() throws IOException {
        return Table.read().csv(BASEBALL_CSV);
    }

    /**
     * Appends to {@code table} a column named {@code RD} holding {@code RS - RA}.
     *
     * @param table a table containing numeric {@code RS} and {@code RA} columns; modified in place
     */
    private static void addRunDifferential(Table table) {
        table.addColumns(table.numberColumn("RS").subtract(table.numberColumn("RA")).setName("RD"));
    }

    /** Fits an OLS model on a numeric dataset built from the W and RD columns, with RD as response. */
    @Test
    public void regression() throws IOException {
        Table baseball = loadBaseball();
        addRunDifferential(baseball);
        Assertions.assertNotNull(
                new OLS(baseball.select("W", "RD").smile().numericDataset("RD")).toString());
    }

    /**
     * Same as {@link #regression()}, but on the first partition of a 60% split stratified by Team.
     * Checks that a table derived by sampling converts just as well as a freshly read one.
     */
    @Test
    public void regressionWithStratifiedSampleTest() throws IOException {
        Table baseball = loadBaseball();
        Table sample = baseball.stratifiedSampleSplit(baseball.stringColumn("Team"), 0.6d)[0];
        addRunDifferential(sample);
        Assertions.assertNotNull(
                new OLS(sample.select("W", "RD").smile().numericDataset("RD")).toString());
    }

    /** Fits a RandomForest on a nominal dataset with Playoffs as response and RS, RA, OBP as inputs. */
    @Test
    public void classification() throws IOException {
        AttributeDataset dataset =
                loadBaseball().smile().nominalDataset("Playoffs", "RS", "RA", "OBP");
        Assertions.assertNotNull(new RandomForest(dataset, 1).toString());
    }

    /** Checks that a nominal dataset that includes the string League column can be rendered. */
    @Test
    public void nominalDatasetToString() throws IOException {
        Assertions.assertNotNull(
                loadBaseball()
                        .smile()
                        .nominalDataset("Playoffs", "League", "RS", "RA", "OBP")
                        .toString());
    }

    /**
     * Checks that the nominal dataset keeps the source column names: the response attribute is
     * named after the response column, and the predictor attributes match the requested columns
     * in the requested order.
     */
    @Test
    public void columnNames() throws IOException {
        String[] predictors = {"League", "RS", "RA", "OBP"};
        AttributeDataset dataset = loadBaseball().smile().nominalDataset("Playoffs", predictors);
        String[] attributeNames =
                Arrays.stream(dataset.attributes())
                        .map(attribute -> attribute.getName())
                        .toArray(String[]::new);
        Assertions.assertEquals("Playoffs", dataset.responseAttribute().getName());
        Assertions.assertArrayEquals(predictors, attributeNames);
    }
}
