package info.vizierdb.commands.mimir.imputation;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.classification.LinearSVC;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import scala.Array$;
import scala.Enumeration;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.immutable.$colon;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: MulticlassImputer.scala */
/* loaded from: input_file:info/vizierdb/commands/mimir/imputation/MulticlassImputer$.class */
/**
 * Companion object of {@code MulticlassImputer}.
 *
 * <p>Holds the registry of classifier-backed imputation strategies. Each entry of
 * {@link #classifierPipelines()} maps a strategy name to a function that takes a training dataset,
 * the name of the column to impute, and a handle-invalid mode, and returns a fitted Spark ML
 * {@link PipelineModel}. Every registered pipeline decodes its numeric prediction back to the
 * original label values in the column named by {@link #PREDICTED_LABEL_COL()}.
 */
public final class MulticlassImputer$ implements Serializable {
    /** Singleton instance; assigned by the private constructor during static initialization. */
    public static MulticlassImputer$ MODULE$;

    /**
     * Number of hashed term-frequency features produced for every string feature column. Shared by
     * {@link #extractFeatures} and {@link #featureVectorSize} so that the computed vector width always
     * matches the width of the assembled {@code "features"} vector.
     */
    private static final int HASHED_TEXT_FEATURES = 20;

    private final String PREDICTED_LABEL_COL;
    private final Map<String, Function3<Dataset<Row>, String, Enumeration.Value, PipelineModel>> classifierPipelines;

    static {
        new MulticlassImputer$();
    }

    /**
     * Returns the name of the column into which every classifier pipeline writes its decoded
     * (original-valued) prediction.
     *
     * @return {@code "predictedLabel"}
     */
    public String PREDICTED_LABEL_COL() {
        return this.PREDICTED_LABEL_COL;
    }

    /**
     * Builds the feature-engineering stages shared by every classifier pipeline.
     *
     * <p>Every column except {@code str} (matched case-insensitively) is a candidate feature:
     * <ul>
     *   <li>string columns are tokenized with {@link RegexTokenizer} and hashed into
     *       {@link #HASHED_TEXT_FEATURES} term-frequency features;</li>
     *   <li>numeric columns are used directly;</li>
     *   <li>columns of any other type are ignored.</li>
     * </ul>
     * Before assembly, nulls in feature columns are replaced using the replacement expressions
     * {@code "''"} (strings) and {@code "0"} (numerics). The assembled {@code "rawFeatures"} vector is
     * L1-normalized into {@code "normFeatures"} and then scaled (std only, no centering) into
     * {@code "features"}. The target column is cast to string and indexed into {@code "label"}.
     *
     * @param dataset training data; used to read the schema and to fit the label indexer
     * @param str     name of the column to impute (the classification target)
     * @param value   handle-invalid mode, passed in string form to the StringIndexer and
     *                VectorAssembler
     * @return a pair of (label values in index order as fitted by the StringIndexer, ordered feature
     *         stages ending with the StandardScaler)
     */
    private Tuple2<String[], Seq<PipelineStage>> extractFeatures(Dataset<Row> dataset, String str, Enumeration.Value value) {
        // Candidate feature fields: everything except the target column.
        StructField[] structFieldArr = (StructField[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset.schema().fields())).filterNot(structField -> {
            return BoxesRunTime.boxToBoolean($anonfun$extractFeatures$1(str, structField));
        });
        // Fit the label indexer eagerly on the string-cast target so the label order is known here.
        Dataset withColumn = dataset.withColumn(str, dataset.apply(str).cast(StringType$.MODULE$));
        CastForStringIndex outputCol = new CastForStringIndex().setInputCol(str).setOutputCol(str);
        StringIndexer handleInvalid = new StringIndexer().setInputCol(str).setOutputCol("label").setHandleInvalid(value.toString());
        String[] labels = handleInvalid.fit(withColumn).labels();
        // One (tokenizer, hashing) pair per string feature column.
        Tuple2 unzip = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structFieldArr)).flatMap(structField2 -> {
            Iterable option2Iterable;
            if (StringType$.MODULE$.equals(structField2.dataType())) {
                RegexTokenizer outputCol2 = new RegexTokenizer().setInputCol(structField2.name()).setOutputCol(new StringBuilder(6).append(structField2.name()).append("_words").toString());
                option2Iterable = Option$.MODULE$.option2Iterable(new Some(new Tuple2(outputCol2, new HashingTF().setInputCol(outputCol2.getOutputCol()).setOutputCol(new StringBuilder(9).append(structField2.name()).append("_features").toString()).setNumFeatures(HASHED_TEXT_FEATURES))));
            } else {
                option2Iterable = Option$.MODULE$.option2Iterable(None$.MODULE$);
            }
            return option2Iterable;
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).unzip(Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(RegexTokenizer.class), ClassTag$.MODULE$.apply(HashingTF.class));
        if (unzip == null) {
            throw new MatchError(unzip);
        }
        Tuple2 tuple2 = new Tuple2((RegexTokenizer[]) unzip._1(), (HashingTF[]) unzip._2());
        RegexTokenizer[] regexTokenizerArr = (RegexTokenizer[]) tuple2._1();
        HashingTF[] hashingTFArr = (HashingTF[]) tuple2._2();
        // Null replacement for string and numeric feature columns; other types get no stage.
        ReplaceNullsForColumn[] replaceNullsForColumnArr = (ReplaceNullsForColumn[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structFieldArr)).flatMap(structField3 -> {
            DataType dataType = structField3.dataType();
            return StringType$.MODULE$.equals(dataType) ? Option$.MODULE$.option2Iterable(new Some(new ReplaceNullsForColumn().setInputColumn(structField3.name()).setOutputColumn(structField3.name()).setReplacementColumn("''"))) : dataType instanceof NumericType ? Option$.MODULE$.option2Iterable(new Some(new ReplaceNullsForColumn().setInputColumn(structField3.name()).setOutputColumn(structField3.name()).setReplacementColumn("0"))) : Option$.MODULE$.option2Iterable(None$.MODULE$);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ReplaceNullsForColumn.class)));
        // Assemble hashed text features and raw numeric columns into one vector.
        VectorAssembler handleInvalid2 = new VectorAssembler().setInputCols((String[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((String[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(structFieldArr)).flatMap(structField4 -> {
            DataType dataType = structField4.dataType();
            return StringType$.MODULE$.equals(dataType) ? Option$.MODULE$.option2Iterable(new Some(new StringBuilder(9).append(structField4.name()).append("_features").toString())) : dataType instanceof NumericType ? Option$.MODULE$.option2Iterable(new Some(structField4.name())) : Option$.MODULE$.option2Iterable(None$.MODULE$);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class))))).toArray(ClassTag$.MODULE$.apply(String.class))).setOutputCol("rawFeatures").setHandleInvalid(value.toString());
        Normalizer p = new Normalizer().setInputCol("rawFeatures").setOutputCol("normFeatures").setP(1.0d);
        StandardScaler withMean = new StandardScaler().setInputCol("normFeatures").setOutputCol("features").setWithStd(true).setWithMean(false);
        // Resulting stage order: CastForStringIndex, StringIndexer, null replacements, tokenizers,
        // hashers, VectorAssembler, Normalizer, StandardScaler (with StagePrinters interleaved).
        return new Tuple2<>(labels, ((List) ((TraversableLike) ((TraversableLike) Nil$.MODULE$.$colon$colon(withMean).$colon$colon(new StagePrinter("norm")).$colon$colon(p).$colon$colon(new StagePrinter("assembler")).$colon$colon(handleInvalid2).$colon$colon(new StagePrinter("tokhash")).$plus$plus$colon(Predef$.MODULE$.wrapRefArray(hashingTFArr), List$.MODULE$.canBuildFrom())).$plus$plus$colon(Predef$.MODULE$.wrapRefArray(regexTokenizerArr), List$.MODULE$.canBuildFrom())).$plus$plus$colon(Predef$.MODULE$.wrapRefArray(replaceNullsForColumnArr), List$.MODULE$.canBuildFrom())).$colon$colon(new StagePrinter("indexer")).$colon$colon(handleInvalid).$colon$colon(new StagePrinter("indexcast")).$colon$colon(outputCol));
    }

    /**
     * Computes the width of the {@code "features"} vector that {@link #extractFeatures} produces for
     * {@code dataset}. Each string column contributes {@link #HASHED_TEXT_FEATURES} slots and each
     * numeric column contributes one slot. The target column {@code str} (matched
     * case-insensitively, as in {@code extractFeatures}) and columns of other types contribute
     * nothing. The Normalizer and StandardScaler stages do not change the vector width.
     *
     * @param dataset data whose schema determines the feature columns
     * @param str     name of the column to impute (excluded from the features)
     * @return the number of entries in the assembled feature vector
     */
    private int featureVectorSize(Dataset<Row> dataset, String str) {
        int size = 0;
        for (StructField field : dataset.schema().fields()) {
            if ($anonfun$extractFeatures$1(str, field)) {
                continue;
            }
            DataType dataType = field.dataType();
            if (StringType$.MODULE$.equals(dataType)) {
                size += HASHED_TEXT_FEATURES;
            } else if (dataType instanceof NumericType) {
                size++;
            }
        }
        return size;
    }

    /**
     * Returns the registry of supported classifier strategies, keyed by strategy name. Each value
     * fits a pipeline on (training data, column to impute, handle-invalid mode).
     */
    public Map<String, Function3<Dataset<Row>, String, Enumeration.Value, PipelineModel>> classifierPipelines() {
        return this.classifierPipelines;
    }

    /** Case-class factory: {@code MulticlassImputer(imputeCol, strategy)}. */
    public MulticlassImputer apply(String str, String str2) {
        return new MulticlassImputer(str, str2);
    }

    /** Case-class extractor; returns {@code (imputeCol, strategy)} or {@code None} for {@code null}. */
    public Option<Tuple2<String, String>> unapply(MulticlassImputer multiclassImputer) {
        return multiclassImputer == null ? None$.MODULE$ : new Some(new Tuple2(multiclassImputer.imputeCol(), multiclassImputer.strategy()));
    }

    /** Preserves the singleton across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** True when {@code structField} is the target column {@code str} (case-insensitive match). */
    public static final /* synthetic */ boolean $anonfun$extractFeatures$1(String str, StructField structField) {
        return structField.name().equalsIgnoreCase(str);
    }

    /** True when {@code structField} has a numeric Spark SQL type. */
    public static final /* synthetic */ boolean $anonfun$classifierPipelines$2(StructField structField) {
        return structField.dataType() instanceof NumericType;
    }

    /**
     * Initializes the singleton and registers the classifier strategies. Every strategy:
     * drops rows containing nulls, builds the shared feature stages via {@link #extractFeatures},
     * appends its classifier and an {@link IndexToString} that decodes predictions into
     * {@link #PREDICTED_LABEL_COL()}, and fits the resulting pipeline.
     */
    private MulticlassImputer$() {
        MODULE$ = this;
        this.PREDICTED_LABEL_COL = "predictedLabel";
        this.classifierPipelines = Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{new Tuple2("NaiveBayes", (dataset, str, value) -> {
            Dataset drop = dataset.na().drop();
            // Replace numeric feature columns by their absolute values (presumably because Spark's
            // NaiveBayes requires non-negative features). The target column is excluded: abs() on it
            // would merge labels such as -1 and 1 and make negative values impossible to impute.
            Dataset<Row> absDataset = (Dataset) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(drop.schema().fields())).filter(structField -> {
                return BoxesRunTime.boxToBoolean(!$anonfun$extractFeatures$1(str, structField) && $anonfun$classifierPipelines$2(structField));
            }))).foldLeft(drop, (dataset2, structField2) -> {
                return dataset2.withColumn(structField2.name(), functions$.MODULE$.abs(dataset2.apply(structField2.name())));
            });
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(absDataset, str, value);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            NaiveBayes featuresCol = new NaiveBayes().setLabelCol("label").setFeaturesCol("features");
            // Stage order: feature stages, StagePrinter("features"), NaiveBayes, StagePrinter("classifier"),
            // IndexToString, StagePrinter("labelconvert"), then fill nulls in the target from the prediction.
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new ReplaceNullsForColumn().setInputColumn(str).setOutputColumn(str).setReplacementColumn(MODULE$.PREDICTED_LABEL_COL())).$colon$colon(new StagePrinter("labelconvert")).$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(new StagePrinter("classifier")).$colon$colon(featuresCol).$colon$colon(new StagePrinter("features")).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(absDataset);
        }), new Tuple2("RandomForest", (dataset2, str2, value2) -> {
            Dataset<Row> drop = dataset2.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str2, value2);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            RandomForestClassifier featuresCol = new RandomForestClassifier().setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(featuresCol).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        }), new Tuple2("DecisionTree", (dataset3, str3, value3) -> {
            Dataset<Row> drop = dataset3.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str3, value3);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            DecisionTreeClassifier featuresCol = new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(featuresCol).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        }), new Tuple2("GradientBoostedTreeBinary", (dataset4, str4, value4) -> {
            Dataset<Row> drop = dataset4.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str4, value4);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            // The feature stages end by writing "features". Index that column into a separate output
            // column; no stage produces an "assembledFeatures" column, and "features" already exists.
            VectorIndexer maxCategories = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(20);
            GBTClassifier maxIter = new GBTClassifier().setLabelCol("label").setFeaturesCol("indexedFeatures").setMaxIter(10);
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(maxIter.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(maxIter).$colon$colon(maxCategories).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop.withColumn(str4, drop.apply(str4).cast(StringType$.MODULE$)));
        }), new Tuple2("LogisticRegression", (dataset5, str5, value5) -> {
            Dataset<Row> drop = dataset5.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str5, value5);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            LogisticRegression featuresCol = new LogisticRegression().setMaxIter(10).setTol(1.0E-6d).setFitIntercept(true).setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(featuresCol).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        }), new Tuple2("OneVsRest", (dataset6, str6, value6) -> {
            Dataset<Row> drop = dataset6.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str6, value6);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            // Base binary classifier wrapped by OneVsRest; its prediction column name is reused for decoding.
            LogisticRegression featuresCol = new LogisticRegression().setMaxIter(10).setTol(1.0E-6d).setFitIntercept(true).setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(new OneVsRest().setClassifier(featuresCol)).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        }), new Tuple2("LinearSupportVectorMachineBinary", (dataset7, str7, value7) -> {
            Dataset<Row> drop = dataset7.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str7, value7);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            LinearSVC featuresCol = new LinearSVC().setMaxIter(10).setRegParam(0.1d).setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(featuresCol).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        }), new Tuple2("MultilayerPerceptron", (dataset8, str8, value8) -> {
            Dataset<Row> drop = dataset8.na().drop();
            Tuple2<String[], Seq<PipelineStage>> extractFeatures = MODULE$.extractFeatures(drop, str8, value8);
            if (extractFeatures == null) {
                throw new MatchError(extractFeatures);
            }
            Tuple2 tuple2 = new Tuple2((String[]) extractFeatures._1(), (Seq) extractFeatures._2());
            String[] strArr = (String[]) tuple2._1();
            Seq seq = (Seq) tuple2._2();
            // The input layer must equal the width of the assembled "features" vector (not the column
            // count). The output layer must equal the number of label indices the StringIndexer emits.
            int[] layers = new int[]{MODULE$.featureVectorSize(drop, str8), 8, 4, strArr.length};
            MultilayerPerceptronClassifier featuresCol = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100).setLabelCol("label").setFeaturesCol("features");
            return new Pipeline().setStages((PipelineStage[]) ((List) Nil$.MODULE$.$colon$colon(new IndexToString().setInputCol(featuresCol.getPredictionCol()).setOutputCol(MODULE$.PREDICTED_LABEL_COL()).setLabels(strArr)).$colon$colon(featuresCol).$plus$plus$colon(seq, List$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(drop);
        })}));
    }
}
