package sklearn.tree;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.HasExtensions;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.DefaultNodeTransformer;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.NodeTransformer;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.visitors.AbstractExtender;
import org.jpmml.model.ValueUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.visitors.TreeModelCompactor;
import org.jpmml.sklearn.visitors.TreeModelFlattener;
import sklearn.Estimator;
import sklearn.HasEstimatorEnsemble;

/* loaded from: input_file:sklearn/tree/TreeModelUtil.class */
public class TreeModelUtil {

    private TreeModelUtil() {
    }

    /**
     * Post-processes a converted tree-based model according to the estimator's tree options.
     *
     * <p>Options read from the estimator:
     * <ul>
     *   <li>{@code OPTION_WINNER_ID} (default {@code false}) - adds a {@code nodeId} entity-id output field.</li>
     *   <li>{@code OPTION_NODE_EXTENSIONS} (default {@code null}) - a map from extension name to a map of
     *       node id to value; each value is attached to the node with that id as a PMML extension.</li>
     *   <li>{@code OPTION_NODE_ID} (defaults to the winner-id setting) - when {@code false}, all node ids are cleared.</li>
     *   <li>{@code OPTION_COMPACT} (defaults to {@code true} unless node ids or node extensions are requested) -
     *       applies {@link TreeModelCompactor}.</li>
     *   <li>{@code OPTION_FLAT} (default {@code false}) - applies {@link TreeModelFlattener}.</li>
     * </ul>
     *
     * <p>Visitors are applied in this order: compactor, flattener, node extenders, id clearing.
     *
     * @param e the estimator whose options drive the transformation
     * @param m the model to transform in place
     * @return {@code m}, for call chaining
     * @throws IllegalArgumentException if node ids or node extensions are requested together with
     *     compaction or flattening
     */
    public static <E extends Estimator & HasTreeOptions, M extends Model> M transform(E e, M m) {
        Boolean winnerId = (Boolean) e.getOption(HasTreeOptions.OPTION_WINNER_ID, Boolean.FALSE);
        // Option values come from an untyped (pickled) options map, hence the unchecked cast;
        // presumably the inner maps are keyed by integer node ids (see createNodeExtender)
        @SuppressWarnings("unchecked")
        Map<String, Map<?, ?>> nodeExtensions = (Map<String, Map<?, ?>>) e.getOption(HasTreeOptions.OPTION_NODE_EXTENSIONS, null);
        Boolean nodeId = (Boolean) e.getOption(HasTreeOptions.OPTION_NODE_ID, winnerId);

        // When node ids are kept or used as extension keys, the compactor and flattener
        // (which presumably restructure the tree) are disabled by default and rejected if forced on
        boolean fixedIds = nodeExtensions != null || nodeId.booleanValue();

        Boolean compact = (Boolean) e.getOption(HasTreeOptions.OPTION_COMPACT, fixedIds ? Boolean.FALSE : Boolean.TRUE);
        Boolean flat = (Boolean) e.getOption(HasTreeOptions.OPTION_FLAT, Boolean.FALSE);

        if (fixedIds && (compact.booleanValue() || flat.booleanValue())) {
            throw new IllegalArgumentException("Conflicting tree model options");
        }

        if (Boolean.TRUE.equals(winnerId)) {
            ModelUtil.ensureOutput(m)
                .addOutputFields(ModelUtil.createEntityIdField(FieldName.create("nodeId")).setDataType(DataType.INTEGER));
        }

        List<Visitor> visitors = new ArrayList<>();
        if (Boolean.TRUE.equals(compact)) {
            visitors.add(new TreeModelCompactor());
        }
        if (Boolean.TRUE.equals(flat)) {
            visitors.add(new TreeModelFlattener());
        }
        if (nodeExtensions != null) {
            for (Map.Entry<String, Map<?, ?>> entry : nodeExtensions.entrySet()) {
                visitors.add(createNodeExtender(entry.getKey(), entry.getValue()));
            }
        }
        if (Boolean.FALSE.equals(nodeId)) {
            // Runs after the extenders, which look nodes up by id
            visitors.add(new AbstractVisitor() {
                @Override
                public VisitorAction visit(Node node) {
                    node.setId((Object) null);
                    return super.visit(node);
                }
            });
        }

        for (Visitor visitor : visitors) {
            visitor.applyTo(m);
        }
        return m;
    }

    /**
     * Creates a visitor that attaches a PMML extension called {@code name} to every node whose id
     * has a non-null entry in {@code valuesById}.
     *
     * <p>Nodes that do not implement {@link HasExtensions} but need an extension are first replaced
     * by their complex-node equivalent via {@link DefaultNodeTransformer}.
     *
     * @param name the extension name
     * @param valuesById extension values, looked up by the node id converted to an {@link Integer}
     * @return the extender visitor
     */
    private static Visitor createNodeExtender(String name, final Map<?, ?> valuesById) {
        return new AbstractExtender(name) {
            private final NodeTransformer nodeTransformer = DefaultNodeTransformer.INSTANCE;

            @Override
            public VisitorAction visit(TreeModel treeModel) {
                // The root has no parent node to replace it in visit(Node), so it is upgraded here
                treeModel.setNode(ensureExtensibility(treeModel.getNode()));
                return super.visit(treeModel);
            }

            @Override
            public VisitorAction visit(Node node) {
                // Children are replaced in place before traversal reaches them
                if (node.hasNodes()) {
                    ListIterator<Node> children = node.getNodes().listIterator();
                    while (children.hasNext()) {
                        children.set(ensureExtensibility(children.next()));
                    }
                }

                Object value = getValue(node);
                if (value != null) {
                    addExtension((HasExtensions<?>) node, ValueUtil.toString(ScalarUtil.decode(value)));
                }
                return super.visit(node);
            }

            /** Returns a node that can carry extensions if one is needed, else the node itself. */
            private Node ensureExtensibility(Node node) {
                if (!(node instanceof HasExtensions) && getValue(node) != null) {
                    return this.nodeTransformer.toComplexNode(node);
                }
                return node;
            }

            /** Looks up the extension value for the node's id, or {@code null} if none. */
            private Object getValue(Node node) {
                return valuesById.get(org.jpmml.converter.ValueUtil.asInteger((Number) node.getId()));
            }
        };
    }

    /**
     * Encodes every member of a tree ensemble as a {@link TreeModel}, using fresh predicate and
     * score distribution managers.
     *
     * @see #encodeTreeModelSegmentation(Estimator, PredicateManager, ScoreDistributionManager, MiningFunction, Schema)
     */
    public static <E extends Estimator & HasEstimatorEnsemble<T>, T extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(E e, MiningFunction miningFunction, Schema schema) {
        return encodeTreeModelSegmentation(e, new PredicateManager(), new ScoreDistributionManager(), miningFunction, schema);
    }

    /**
     * Encodes every member of a tree ensemble as a {@link TreeModel}.
     *
     * <p>Each member is encoded against the anonymous form of {@code schema}, with its features
     * converted to the member's own data type (see {@link #toTreeModelSchema(DataType, Schema)}).
     * The given managers are shared by all members.
     *
     * @return one tree model per ensemble member, in ensemble order
     */
    public static <E extends Estimator & HasEstimatorEnsemble<T>, T extends Estimator & HasTree> List<TreeModel> encodeTreeModelSegmentation(E e, final PredicateManager predicateManager, final ScoreDistributionManager scoreDistributionManager, final MiningFunction miningFunction, Schema schema) {
        List<? extends T> estimators = e.getEstimators();
        final Schema segmentSchema = schema.toAnonymousSchema();

        return estimators.stream()
            .map(estimator -> encodeTreeModel(estimator, predicateManager, scoreDistributionManager, miningFunction, toTreeModelSchema(estimator.getDataType(), segmentSchema)))
            .collect(Collectors.toList());
    }

    /**
     * Encodes a single tree estimator as a {@link TreeModel}, using fresh predicate and score
     * distribution managers.
     */
    public static <E extends Estimator & HasTree> TreeModel encodeTreeModel(E e, MiningFunction miningFunction, Schema schema) {
        return encodeTreeModel(e, new PredicateManager(), new ScoreDistributionManager(), miningFunction, schema);
    }

    /**
     * Encodes a single tree estimator as a binary-split {@link TreeModel}.
     *
     * <p>After encoding, {@link ClassDictUtil#clearContent} is called on the estimator's tree
     * (presumably to release the parsed tree arrays).
     *
     * @throws IllegalArgumentException if {@code miningFunction} is neither classification nor regression
     */
    public static <E extends Estimator & HasTree> TreeModel encodeTreeModel(E e, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, MiningFunction miningFunction, Schema schema) {
        Tree tree = e.getTree();

        TreeModel treeModel = new TreeModel(
                miningFunction,
                ModelUtil.createMiningSchema(schema.getLabel()),
                encodeNode(True.INSTANCE, predicateManager, scoreDistributionManager, 0, tree.getChildrenLeft(), tree.getChildrenRight(), tree.getFeature(), tree.getThreshold(), tree.getValues(), miningFunction, schema))
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

        ClassDictUtil.clearContent(tree);

        return treeModel;
    }

    /**
     * Recursively encodes the node at {@code index} and its subtree.
     *
     * <p>The arrays are the tree's parallel per-node arrays: left child index, right child index,
     * split feature index, split threshold, and node values. A negative feature index marks a leaf.
     *
     * @param predicate the predicate that routes records from the parent to this node
     * @throws IllegalArgumentException on an unsupported mining function, or on a binary-feature
     *     split whose threshold lies outside [0, 1]
     */
    private static Node encodeNode(Predicate predicate, PredicateManager predicateManager, ScoreDistributionManager scoreDistributionManager, int index, int[] leftChildren, int[] rightChildren, int[] features, double[] thresholds, double[] values, MiningFunction miningFunction, Schema schema) {
        int featureIndex = features[index];

        if (featureIndex < 0) {
            return encodeLeafNode(predicate, scoreDistributionManager, index, leftChildren.length, values, miningFunction, schema);
        }

        Feature feature = schema.getFeature(featureIndex);
        double threshold = thresholds[index];

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            // A split on a 0/1 indicator is only meaningful with a threshold inside [0, 1]:
            // the "<= threshold" side is then "not equal to value", the other side "equal to value"
            if (threshold < 0d || threshold > 1d) {
                throw new IllegalArgumentException("Expected a binary feature split threshold within [0, 1], got " + threshold);
            }

            Object value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.GREATER_THAN == null ? null : SimplePredicate.Operator.EQUAL, value);
        } else {
            // Cast to float first, then to double; presumably mirrors scikit-learn converting
            // tree inputs to float32 before comparing them against (double) thresholds
            ContinuousFeature continuousFeature = feature
                .toContinuousFeature(DataType.FLOAT)
                .toContinuousFeature(DataType.DOUBLE);

            Double value = Double.valueOf(threshold);

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        Node leftChild = encodeNode(leftPredicate, predicateManager, scoreDistributionManager, leftChildren[index], leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);
        Node rightChild = encodeNode(rightPredicate, predicateManager, scoreDistributionManager, rightChildren[index], leftChildren, rightChildren, features, thresholds, values, miningFunction, schema);

        Node result;
        if (MiningFunction.CLASSIFICATION.equals(miningFunction)) {
            result = new ClassifierNode((Object) null, predicate);
        } else if (MiningFunction.REGRESSION.equals(miningFunction)) {
            result = new CountingBranchNode((Object) null, predicate);
        } else {
            throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
        }

        result.setId(Integer.valueOf(index)).addNodes(leftChild, rightChild);
        return result;
    }

    /**
     * Encodes the leaf node at {@code index}.
     *
     * <p>Classification leaves predict the class with the largest per-class value (the first one
     * wins on ties), use the sum of the per-class values as their record count, and carry a score
     * distribution. Regression leaves carry {@code values[index]} as their score.
     *
     * @param nodeCount the total number of nodes in the tree, used to validate the layout of {@code values}
     * @throws IllegalArgumentException if {@code miningFunction} is neither classification nor regression
     */
    private static Node encodeLeafNode(Predicate predicate, ScoreDistributionManager scoreDistributionManager, int index, int nodeCount, double[] values, MiningFunction miningFunction, Schema schema) {
        Integer id = Integer.valueOf(index);

        if (MiningFunction.CLASSIFICATION.equals(miningFunction)) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

            double[] classValues = getRow(values, nodeCount, categoricalLabel.size(), index);

            double recordCount = 0d;
            Object score = null;
            double maxValue = -Double.MAX_VALUE;

            for (int i = 0; i < classValues.length; i++) {
                double classValue = classValues[i];

                recordCount += classValue;

                // Strict comparison keeps the first class among equal maxima
                if (classValue > maxValue) {
                    score = categoricalLabel.getValue(i);
                    maxValue = classValue;
                }
            }

            Node result = new ClassifierNode(score, predicate)
                .setId(id)
                .setRecordCount(Double.valueOf(recordCount));

            result.getScoreDistributions().addAll(scoreDistributionManager.createScoreDistribution(categoricalLabel, classValues));
            return result;
        } else if (MiningFunction.REGRESSION.equals(miningFunction)) {
            return new CountingLeafNode(Double.valueOf(values[index]), predicate).setId(id);
        }

        throw new IllegalArgumentException("Unsupported mining function: " + miningFunction);
    }

    /**
     * Derives a schema for tree model encoding: binary features are kept as-is, and every other
     * feature is converted to a continuous feature of {@code dataType}.
     */
    public static Schema toTreeModelSchema(final DataType dataType, Schema schema) {
        return schema.toTransformedSchema(feature -> (feature instanceof BinaryFeature) ? feature : feature.toContinuousFeature(dataType));
    }

    /**
     * Copies row {@code rowIndex} out of a row-major {@code rows x columns} matrix stored in {@code data}.
     *
     * @throws IllegalArgumentException if {@code data} does not hold exactly {@code rows * columns} elements
     */
    private static double[] getRow(double[] data, int rows, int columns, int rowIndex) {
        if (data.length != rows * columns) {
            throw new IllegalArgumentException("Expected " + (rows * columns) + " element(s), got " + data.length + " element(s)");
        }

        double[] row = new double[columns];
        System.arraycopy(data, rowIndex * columns, row, 0, columns);
        return row;
    }
}
