package org.jpmml.sparkml.feature;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.spark.ml.feature.ImputerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.dmg.pmml.DataField;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueDecorator;
import org.jpmml.converter.ValueUtil;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/* loaded from: input_file:org/jpmml/sparkml/feature/ImputerModelConverter.class */
/**
 * Converts a Spark ML {@link ImputerModel} into PMML missing value treatment.
 *
 * <p>No new derived fields are created. Instead, the field behind each imputed input column gets
 * a {@link MissingValueDecorator}, which carries the surrogate value computed by Spark. The unchanged
 * input feature is then registered under the corresponding output column name.
 */
public class ImputerModelConverter extends FeatureConverter<ImputerModel> {

    public ImputerModelConverter(ImputerModel imputerModel) {
        super(imputerModel);
    }

    /**
     * Attaches a missing value decorator to the field behind every input column of the imputer.
     *
     * @param sparkMLEncoder the encoder that holds the features of the input columns
     * @return the (unchanged) input features, one per input column, in input-column order
     * @throws IllegalArgumentException if the input and output column counts differ, if the
     *     surrogate data frame does not contain exactly one row, if an input column is not backed
     *     by a {@link DataField}, or if the imputation strategy is not supported
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder sparkMLEncoder) {
        ImputerModel imputerModel = (ImputerModel) getTransformer();

        Double missingValue = imputerModel.getMissingValue();
        String strategy = imputerModel.getStrategy();
        Dataset<Row> surrogateDF = imputerModel.surrogateDF();

        String[] inputCols = imputerModel.getInputCols();
        String[] outputCols = imputerModel.getOutputCols();
        if (inputCols.length != outputCols.length) {
            throw new IllegalArgumentException(
                "Expected " + inputCols.length + " output columns, got " + outputCols.length);
        }

        // Resolve the strategy up front so an unsupported one fails before any decorator is added
        MissingValueTreatmentMethod missingValueTreatmentMethod = parseStrategy(strategy);

        // The surrogate frame is read as a single row, holding one value per input column name
        List<Row> surrogateRows = surrogateDF.collectAsList();
        if (surrogateRows.size() != 1) {
            throw new IllegalArgumentException(
                "Expected exactly one surrogate row, got " + surrogateRows.size());
        }
        Row surrogateRow = surrogateRows.get(0);

        // A NaN placeholder is not declared as an explicit PMML missing value; only a concrete
        // (non-NaN) placeholder is. NOTE(review): assumes NaN is already treated as missing
        // downstream — confirm. Formatted once, since it is the same for every column.
        String formattedMissingValue = missingValue.isNaN() ? null : ValueUtil.formatValue(missingValue);

        List<Feature> result = new ArrayList<>(inputCols.length);
        for (String inputCol : inputCols) {
            Feature feature = sparkMLEncoder.getOnlyFeature(inputCol);

            // Decorators are only attached to raw input fields, not to derived fields
            if (!(sparkMLEncoder.getField(feature.getName()) instanceof DataField)) {
                throw new IllegalArgumentException(
                    "Input column " + inputCol + " is not backed by a DataField");
            }

            MissingValueDecorator decorator = new MissingValueDecorator()
                .setMissingValueReplacement(ValueUtil.formatValue(surrogateRow.getAs(inputCol)))
                .setMissingValueTreatment(missingValueTreatmentMethod);
            if (formattedMissingValue != null) {
                decorator.addValues(new String[]{formattedMissingValue});
            }

            sparkMLEncoder.addDecorator(feature.getName(), decorator);
            result.add(feature);
        }
        return result;
    }

    /**
     * Encodes the input features and registers each under its matching output column name.
     *
     * @param sparkMLEncoder the encoder to register the output features with
     * @throws IllegalArgumentException if the number of encoded features does not match the
     *     number of output columns (or for any reason listed on {@link #encodeFeatures})
     */
    @Override
    public void registerFeatures(SparkMLEncoder sparkMLEncoder) {
        ImputerModel imputerModel = (ImputerModel) getTransformer();

        List<Feature> features = encodeFeatures(sparkMLEncoder);
        String[] outputCols = imputerModel.getOutputCols();
        if (outputCols.length != features.size()) {
            throw new IllegalArgumentException(
                "Expected " + outputCols.length + " features, got " + features.size());
        }

        for (int i = 0; i < outputCols.length; i++) {
            sparkMLEncoder.putFeatures(outputCols[i], Collections.singletonList(features.get(i)));
        }
    }

    /**
     * Maps a Spark {@code Imputer} strategy name to the corresponding PMML treatment method.
     *
     * @param strategy the Spark strategy name ({@code "mean"}, {@code "median"} or {@code "mode"})
     * @return the matching {@link MissingValueTreatmentMethod}
     * @throws IllegalArgumentException if the strategy name is not recognized
     */
    public static MissingValueTreatmentMethod parseStrategy(String strategy) {
        switch (strategy) {
            case "mean":
                return MissingValueTreatmentMethod.AS_MEAN;
            case "median":
                return MissingValueTreatmentMethod.AS_MEDIAN;
            case "mode":
                return MissingValueTreatmentMethod.AS_MODE;
            default:
                throw new IllegalArgumentException(strategy);
        }
    }
}
