package sklearn;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:sklearn/Estimator.class */
/**
 * Abstract base for estimator steps that can be encoded as a PMML {@link Model}.
 *
 * <p>Subclasses supply the mining function and the actual model encoding; this class
 * layers common post-processing on top (model name, algorithm name, feature importances)
 * and provides accessors for a handful of well-known attributes
 * ({@code n_features_}, {@code feature_importances_}, {@code pmml_feature_importances_},
 * {@code pmml_options_}, {@code _sklearn_version}).</p>
 *
 * <p>Attribute access goes through helpers that are not declared in this class
 * ({@code containsKey}, {@code get}, {@code put}, {@code getInteger}, {@code getDict},
 * {@code getNumberArray}, {@code toArray}, {@code getOptionalString}); they are presumably
 * inherited from {@link Step} or one of its ancestors.</p>
 */
public abstract class Estimator extends Step {
    private static final Logger logger = LoggerFactory.getLogger(Estimator.class);

    /**
     * Creates an estimator by forwarding both arguments unchanged to the {@link Step} constructor.
     *
     * @param str  first constructor argument of {@link Step}
     *             (NOTE(review): presumably the Python module name - confirm against {@code Step})
     * @param str2 second constructor argument of {@link Step}
     *             (NOTE(review): presumably the Python class name - confirm against {@code Step})
     */
    public Estimator(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the PMML mining function implemented by this estimator.
     *
     * @return the mining function; {@link #isSupervised()} dereferences it, so it should not be {@code null}
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Encodes this estimator as a PMML model, without the post-processing applied by
     * {@link #encode(Schema)}.
     *
     * <p>The unusual method name is a decompiler rename of {@code encodeModel}; it is kept
     * as-is so that existing overrides and callers keep resolving.</p>
     *
     * @param schema the schema to encode the model against
     * @return the encoded model
     */
    /* renamed from: encodeModel */
    public abstract Model mo1encodeModel(Schema schema);

    /**
     * Returns the value of the {@code n_features_} attribute.
     *
     * @return the attribute value, or {@code -1} when the attribute is absent
     */
    public int getNumberOfFeatures() {
        if (containsKey("n_features_")) {
            return getInteger("n_features_").intValue();
        }
        return -1;
    }

    /**
     * Returns the operational type reported by this estimator.
     *
     * @return always {@link OpType#CONTINUOUS} in this base class
     */
    @Override // sklearn.HasType
    public OpType getOpType() {
        return OpType.CONTINUOUS;
    }

    /**
     * Returns the data type reported by this estimator.
     *
     * @return always {@link DataType#DOUBLE} in this base class
     */
    @Override // sklearn.HasType
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /**
     * Tells whether this estimator is supervised, based on {@link #getMiningFunction()}.
     *
     * @return {@code true} for {@link MiningFunction#CLASSIFICATION} and
     *         {@link MiningFunction#REGRESSION}, {@code false} for {@link MiningFunction#CLUSTERING}
     * @throws IllegalArgumentException if the mining function is any other value
     */
    public boolean isSupervised() {
        MiningFunction miningFunction = getMiningFunction();
        // Switch on the enum itself; the compiler generates the ordinal lookup table.
        switch (miningFunction) {
            case CLASSIFICATION:
            case REGRESSION:
                return true;
            case CLUSTERING:
                return false;
            default:
                throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }
    }

    /**
     * Encodes this estimator and applies common post-processing to the result.
     *
     * <p>After delegating to {@link #mo1encodeModel(Schema)}, the model name is set from
     * {@code getPMMLName()} (only when the model has no name and that value is non-null),
     * the algorithm name is set from {@code getClassName()} (only when the model has none),
     * and feature importances are attached via {@link #addFeatureImportances(Model, Schema)}.</p>
     *
     * @param schema the schema to encode the model against
     * @return the post-processed model
     */
    public Model encode(Schema schema) {
        Model model = mo1encodeModel(schema);
        if (model.getModelName() == null) {
            String pmmlName = getPMMLName();
            if (pmmlName != null) {
                model.setModelName(pmmlName);
            }
        }
        if (model.getAlgorithmName() == null) {
            model.setAlgorithmName(getClassName());
        }
        addFeatureImportances(model, schema);
        return model;
    }

    /**
     * Registers one feature importance per schema feature with the schema's encoder.
     *
     * <p>The values from {@link #getPMMLFeatureImportances()} take precedence; when those are
     * absent, {@link #getFeatureImportances()} is used. When neither is available the method
     * returns without touching the schema.</p>
     *
     * @param model  the model the importances are registered for
     * @param schema supplies the encoder and the feature list; the feature list is size-checked
     *               against the importances via {@code ClassDictUtil.checkSize}
     */
    public void addFeatureImportances(Model model, Schema schema) {
        List<? extends Number> importances = getPMMLFeatureImportances();
        if (importances == null) {
            importances = getFeatureImportances();
        }
        if (importances == null) {
            return;
        }
        ModelEncoder encoder = schema.getEncoder();
        List<? extends Feature> features = schema.getFeatures();
        ClassDictUtil.checkSize(features, importances);
        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);
            encoder.addFeatureImportance(model, feature.getName(), importances.get(i));
        }
    }

    /**
     * Looks up an option value by name.
     *
     * <p>Resolution order:</p>
     * <ol>
     *   <li>the entry of the {@code pmml_options_} map, when that map exists and contains the key;</li>
     *   <li>the same-named attribute of this object, when present (a warning is logged);</li>
     *   <li>the supplied default.</li>
     * </ol>
     *
     * @param str the option name
     * @param obj the default value returned when the option cannot be resolved
     * @return the resolved option value, which may be {@code null} if the stored value is {@code null}
     */
    public Object getOption(String str, Object obj) {
        Map<String, ?> pmmlOptions = getPMMLOptions();
        if (pmmlOptions != null && pmmlOptions.containsKey(str)) {
            return pmmlOptions.get(str);
        }
        if (!containsKey(str)) {
            return obj;
        }
        logger.warn("Attribute '{}' is not set. Falling back to the surrogate attribute '{}'",
                ClassDictUtil.formatMember(this, "pmml_options_"), ClassDictUtil.formatMember(this, str));
        return get(str);
    }

    /**
     * Tells whether either feature importances attribute is present.
     *
     * @return {@code true} if {@code feature_importances_} or {@code pmml_feature_importances_} is present
     */
    public boolean hasFeatureImportances() {
        return containsKey("feature_importances_") || containsKey("pmml_feature_importances_");
    }

    /**
     * Returns the {@code feature_importances_} attribute as a list of numbers.
     *
     * @return the importances, or {@code null} when the attribute is absent
     */
    public List<? extends Number> getFeatureImportances() {
        if (containsKey("feature_importances_")) {
            return getNumberArray("feature_importances_");
        }
        return null;
    }

    /**
     * Returns the {@code pmml_feature_importances_} attribute as a list of numbers.
     *
     * @return the importances, or {@code null} when the attribute is absent
     */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (containsKey("pmml_feature_importances_")) {
            return getNumberArray("pmml_feature_importances_");
        }
        return null;
    }

    /**
     * Stores the given values under the {@code pmml_feature_importances_} attribute,
     * after converting them with {@code toArray}.
     *
     * @param list the importances to store
     * @return this estimator, for call chaining
     */
    public Estimator setPMMLFeatureImportances(List<? extends Number> list) {
        put("pmml_feature_importances_", toArray(list));
        return this;
    }

    /**
     * Returns the {@code pmml_options_} attribute as a map.
     *
     * @return the options, or {@code null} when {@code get("pmml_options_")} yields {@code null}
     */
    public Map<String, ?> getPMMLOptions() {
        if (get("pmml_options_") == null) {
            return null;
        }
        return getDict("pmml_options_");
    }

    /**
     * Stores the given map under the {@code pmml_options_} attribute.
     *
     * @param map the options to store
     * @return this estimator, for call chaining
     */
    public Estimator setPMMLOptions(Map<String, ?> map) {
        put("pmml_options_", map);
        return this;
    }

    /**
     * Returns the {@code _sklearn_version} attribute, read via {@code getOptionalString}.
     *
     * @return the attribute value as returned by {@code getOptionalString}
     *         (NOTE(review): presumably {@code null} when absent - confirm against that helper)
     */
    public String getSkLearnVersion() {
        return getOptionalString("_sklearn_version");
    }
}
