package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.xml.parsers.DocumentBuilder;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.param.shared.HasProbabilityCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldColumnPair;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Row;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;

/**
 * Base converter for Spark ML classification models.
 *
 * <p>Reports {@link MiningFunction#CLASSIFICATION} as the mining function and registers the
 * prediction output fields of the converted model, plus per-class probability output fields
 * when the model implements {@link HasProbabilityCol}.
 *
 * @param <T> the Spark ML prediction model type being converted
 */
public abstract class ClassificationModelConverter<T extends PredictionModel<Vector, T> & HasLabelCol & HasFeaturesCol & HasPredictionCol> extends ModelConverter<T> {

    public ClassificationModelConverter(T model) {
        super(model);
    }

    @Override
    public MiningFunction getMiningFunction() {
        return MiningFunction.CLASSIFICATION;
    }

    /**
     * Registers the PMML output fields that reproduce this model's Spark ML output columns.
     *
     * <p>The following fields are created, in this order:
     * <ol>
     *   <li>{@code pmml(<predictionCol>)}: the PMML predicted value, with the label's data type;</li>
     *   <li>{@code <predictionCol>}: a {@code DOUBLE}, categorical field that maps each label value
     *       to its zero-based position in the label through a {@link MapValues} lookup table;</li>
     *   <li>{@code <probabilityCol>(<value>)}: one {@code DOUBLE} probability field per label value,
     *       created only when the model implements {@link HasProbabilityCol}.</li>
     * </ol>
     * The {@code <predictionCol>} field is registered with the encoder as the only feature of the
     * prediction column, and the probability fields (if any) as the features of the probability
     * column.
     *
     * @param label the model's label; must be a {@link CategoricalLabel}
     * @param sparkMLEncoder the encoder that receives the derived features
     * @return the created output fields, in registration order
     * @throws ClassCastException if {@code label} is not a {@link CategoricalLabel}
     */
    @Override
    public List<OutputField> registerOutputFields(Label label, SparkMLEncoder sparkMLEncoder) {
        T model = getTransformer();

        CategoricalLabel categoricalLabel = (CategoricalLabel) label;

        List<OutputField> result = new ArrayList<>();

        // T is bounded by HasPredictionCol, so the prediction column is always available.
        String predictionCol = model.getPredictionCol();

        OutputField pmmlPredictedField = ModelUtil.createPredictedField(
                FieldName.create("pmml(" + predictionCol + ")"), categoricalLabel.getDataType(), OpType.CATEGORICAL);
        result.add(pmmlPredictedField);

        List<String> categories = new ArrayList<>(categoricalLabel.size());

        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();

        InlineTable inlineTable = new InlineTable();

        // Lookup table translating each label value ("input") to its index ("output"),
        // presumably to mirror Spark ML's index-encoded prediction column.
        List<String> columns = Arrays.asList("input", "output");

        for (int i = 0; i < categoricalLabel.size(); i++) {
            String value = categoricalLabel.getValue(i);
            String category = String.valueOf(i);

            categories.add(category);

            Row row = DOMUtil.createRow(documentBuilder, columns, Arrays.asList(value, category));
            inlineTable.addRows(row);
        }

        MapValues mapValues = new MapValues()
                .addFieldColumnPairs(new FieldColumnPair(pmmlPredictedField.getName(), columns.get(0)))
                .setOutputColumn(columns.get(1))
                .setInlineTable(inlineTable);

        OutputField predictedField = new OutputField(FieldName.create(predictionCol), DataType.DOUBLE)
                .setOpType(OpType.CATEGORICAL)
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setExpression(mapValues);
        result.add(predictedField);

        Feature predictionFeature = new CategoricalFeature(sparkMLEncoder, predictedField, categories) {

            @Override
            public ContinuousFeature toContinuousFeature() {
                // The backing field is DOUBLE-typed (index-valued), so expose it as a continuous feature too.
                return new ContinuousFeature(ensureEncoder(), this);
            }
        };
        sparkMLEncoder.putOnlyFeature(predictionCol, predictionFeature);

        // Probability outputs exist only for models that actually expose a probability column.
        if (model instanceof HasProbabilityCol) {
            HasProbabilityCol hasProbabilityCol = (HasProbabilityCol) model;

            String probabilityCol = hasProbabilityCol.getProbabilityCol();

            List<Feature> probabilityFeatures = new ArrayList<>(categoricalLabel.size());

            for (int i = 0; i < categoricalLabel.size(); i++) {
                String value = categoricalLabel.getValue(i);

                OutputField probabilityField = ModelUtil.createProbabilityField(
                        FieldName.create(probabilityCol + "(" + value + ")"), DataType.DOUBLE, value);
                result.add(probabilityField);

                probabilityFeatures.add(new ContinuousFeature(sparkMLEncoder, probabilityField));
            }

            sparkMLEncoder.putFeatures(probabilityCol, probabilityFeatures);
        }

        return result;
    }
}
