package h2o.estimators;

import hex.genmodel.MojoModel;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureList;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.ConverterFactory;
import org.jpmml.h2o.H2OEncoder;
import org.jpmml.h2o.MojoModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;
import sklearn.HasClasses;

/**
 * Estimator backed by an H2O MOJO model.
 *
 * <p>The MOJO model is read lazily from either the {@code _mojo_bytes} attribute (raw bytes) or,
 * when that attribute is absent, from the file named by the {@code _mojo_path} attribute. PMML
 * encoding is delegated to a {@link Converter} obtained from {@link ConverterFactory}.
 *
 * <p>Only the estimator types {@code "classifier"} and {@code "regressor"} are supported; any
 * other value of {@code _estimator_type} makes the type-dependent methods throw
 * {@link IllegalArgumentException}.
 *
 * <p>NOTE(review): the cached MOJO model is initialized without synchronization, so instances
 * should not be shared across threads during first use.
 */
public class H2OEstimator extends Estimator implements HasClasses, Encodable {
    private MojoModel mojoModel;

    public H2OEstimator(String module, String name) {
        super(module, name);
    }

    /**
     * Returns {@link MiningFunction#CLASSIFICATION} for classifiers and
     * {@link MiningFunction#REGRESSION} for regressors.
     *
     * @throws IllegalArgumentException if the estimator type is not supported
     */
    public MiningFunction getMiningFunction() {
        return isClassifier(getEstimatorType()) ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION;
    }

    /**
     * Returns {@code true}; both supported estimator types are supervised.
     *
     * @throws IllegalArgumentException if the estimator type is not supported
     */
    public boolean isSupervised() {
        // Called only to reject unsupported estimator types.
        isClassifier(getEstimatorType());
        return true;
    }

    /**
     * Returns {@code 1}; both supported estimator types have a single output.
     *
     * @throws IllegalArgumentException if the estimator type is not supported
     */
    public int getNumberOfOutputs() {
        // Called only to reject unsupported estimator types.
        isClassifier(getEstimatorType());
        return 1;
    }

    /**
     * Returns the domain values of the MOJO model's response column.
     *
     * @throws IllegalArgumentException if the response column has no domain values
     */
    public List<?> getClasses() {
        MojoModel model = getMojoModel();
        String[] domainValues = model.getDomainValues(model.getResponseIdx());
        if (domainValues == null) {
            throw new IllegalArgumentException("The response column of the MOJO model has no domain values");
        }
        return Arrays.asList(domainValues);
    }

    /**
     * Returns {@code true} for classifiers and {@code false} for regressors.
     *
     * @throws IllegalArgumentException if the estimator type is not supported
     */
    public boolean hasProbabilityDistribution() {
        return isClassifier(getEstimatorType());
    }

    /**
     * Encodes the target label.
     *
     * <p>Classifiers get a categorical label over {@link #getClasses()}. Regressors get a
     * continuous double label. A data field is registered with {@code encoder} only when the
     * single name is non-null.
     *
     * @param names exactly one target name, which may be {@code null}
     * @param encoder encoder used to create the target data field
     * @throws IllegalArgumentException if the estimator type is not supported
     */
    public Label encodeLabel(List<String> names, SkLearnEncoder encoder) {
        String estimatorType = getEstimatorType();
        ClassDictUtil.checkSize(1, names);
        String name = names.get(0);

        if (isClassifier(estimatorType)) {
            List<?> classes = getClasses();
            DataType dataType = TypeUtil.getDataType(classes, DataType.STRING);
            if (name == null) {
                return new CategoricalLabel(dataType, classes);
            }
            return new CategoricalLabel(encoder.createDataField(name, OpType.CATEGORICAL, dataType, classes));
        }

        if (name == null) {
            return new ContinuousLabel(DataType.DOUBLE);
        }
        return new ContinuousLabel(encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE));
    }

    /**
     * Encodes the MOJO model as a PMML model.
     *
     * <p>The caller's features are reordered to follow the features of the converter's own
     * schema before the converter encodes the model.
     *
     * @throws IllegalArgumentException if a converter feature cannot be matched to a caller feature
     */
    public Model encodeModel(Schema schema) {
        Converter<?> converter = createConverter();

        PMMLEncoder encoder = schema.getEncoder();
        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        List<? extends Feature> h2oFeatures = converter.encodeSchema(new H2OEncoder()).getFeatures();

        List<Feature> sortedFeatures = new ArrayList<>(h2oFeatures.size());
        for (Feature h2oFeature : h2oFeatures) {
            sortedFeatures.add(resolveFeature(features, h2oFeature.getName()));
        }

        Schema mojoModelSchema = converter.toMojoModelSchema(new Schema(encoder, label, sortedFeatures));
        return converter.encodeModel(mojoModelSchema);
    }

    /** Encodes the whole MOJO model as a standalone PMML document. */
    public PMML encodePMML() {
        return createConverter().encodePMML();
    }

    /** Returns the {@code _estimator_type} attribute. */
    public String getEstimatorType() {
        return getString("_estimator_type");
    }

    /** Returns the {@code _mojo_bytes} attribute. */
    public byte[] getMojoBytes() {
        return (byte[]) get("_mojo_bytes", byte[].class);
    }

    /** Returns the {@code _mojo_path} attribute. */
    public String getMojoPath() {
        return getString("_mojo_path");
    }

    /** Sets the {@code _mojo_path} attribute and returns this estimator for chaining. */
    public H2OEstimator setMojoPath(String mojoPath) {
        put("_mojo_path", mojoPath);
        return this;
    }

    /**
     * Maps an estimator type to {@code true} for {@code "classifier"} and {@code false} for
     * {@code "regressor"}.
     *
     * @throws IllegalArgumentException for any other estimator type
     */
    private static boolean isClassifier(String estimatorType) {
        switch (estimatorType) {
            case "classifier":
                return true;
            case "regressor":
                return false;
            default:
                throw new IllegalArgumentException(estimatorType);
        }
    }

    /**
     * Looks up a feature by name.
     *
     * <p>A {@link FeatureList} resolves the name itself. Otherwise the name is searched with
     * {@link FeatureUtil#findFeature}. If that search fails, the name is treated as a positional
     * reference: the first character is dropped and the rest is a 1-based index (e.g.
     * {@code "x3"} selects the third feature).
     */
    private static Feature resolveFeature(List<? extends Feature> features, String name) {
        if (features instanceof FeatureList) {
            return ((FeatureList) features).resolveFeature(name);
        }

        Feature feature = FeatureUtil.findFeature(features, name);
        if (feature != null) {
            return feature;
        }

        int index;
        try {
            index = Integer.parseInt(name.substring(1)) - 1;
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Cannot resolve feature " + name, e);
        }
        if (index < 0 || index >= features.size()) {
            throw new IllegalArgumentException(
                "Cannot resolve feature " + name + ": positional index out of range for " + features.size() + " feature(s)");
        }
        return features.get(index);
    }

    /** Creates a converter for the (lazily loaded) MOJO model, wrapping any failure. */
    private Converter<?> createConverter() {
        try {
            return ConverterFactory.newConverterFactory().newConverter(getMojoModel());
        } catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }

    /** Returns the cached MOJO model, loading it on first access. */
    private MojoModel getMojoModel() {
        if (this.mojoModel == null) {
            this.mojoModel = loadMojoModel();
        }
        return this.mojoModel;
    }

    /**
     * Reads the MOJO model from {@code _mojo_bytes} if present, otherwise from the file at
     * {@code _mojo_path}.
     *
     * @throws IllegalArgumentException wrapping any failure to read the model
     */
    private MojoModel loadMojoModel() {
        try {
            if (containsKey("_mojo_bytes")) {
                try (ByteArrayInputStream is = new ByteArrayInputStream(getMojoBytes())) {
                    return MojoModelUtil.readFrom(is);
                }
            }
            return MojoModelUtil.readFrom(new File(getMojoPath()), false);
        } catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }
}
