package sklearn.compose;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.jpmml.converter.Feature;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.sklearn.CastFunction;
import org.jpmml.sklearn.ClassDictUtil;
import org.jpmml.sklearn.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Initializer;
import sklearn.MultiTransformer;
import sklearn.Transformer;

/* loaded from: input_file:sklearn/compose/ColumnTransformer.class */
public class ColumnTransformer extends Initializer {

    /**
     * Stand-in transformer for the {@code "drop"} column specification: it discards its input
     * columns and contributes no features.
     */
    public static class Drop extends MultiTransformer {
        public static final Drop INSTANCE = new Drop();

        private Drop() {
            super(null, null);
        }

        /**
         * Returns an empty feature list regardless of the input.
         */
        @Override
        public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
            return Collections.emptyList();
        }
    }

    /**
     * Stand-in transformer for the {@code "passthrough"} column specification: it forwards its
     * input features unchanged.
     */
    public static class PassThrough extends MultiTransformer {
        public static final PassThrough INSTANCE = new PassThrough();

        private PassThrough() {
            super(null, null);
        }

        /**
         * Returns the given feature list as-is.
         */
        @Override
        public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
            return features;
        }
    }

    public ColumnTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this column transformer when it is the first step of a pipeline, that is, when
     * there are no upstream features. Columns are then resolved to wildcard data fields.
     */
    @Override
    public List<Feature> initializeFeatures(SkLearnEncoder encoder) {
        // Explicit type witness: avoids relying on Java 8 target typing for method arguments.
        return encodeFeatures(Collections.<Feature>emptyList(), encoder);
    }

    /**
     * Runs every fitted transformer on its selected columns and concatenates the results in
     * declaration order.
     *
     * @param features the upstream features; when empty, columns resolve to wildcard data fields
     * @param encoder the encoder used to look up or create data fields
     * @return the concatenation of all transformers' output features
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        List<Object[]> fittedTransformers = getFittedTransformers();

        List<Feature> result = new ArrayList<>();
        for (Object[] fittedTransformer : fittedTransformers) {
            Transformer transformer = getTransformer(fittedTransformer);
            List<Feature> selectedFeatures = getFeatures(fittedTransformer, features, encoder);

            result.addAll(transformer.updateAndEncodeFeatures(selectedFeatures, encoder));
        }
        return result;
    }

    /**
     * Returns the {@code transformers_} attribute as a list of tuples. In each tuple, element 1
     * is the transformer (or the string {@code "drop"} / {@code "passthrough"}) and element 2 is
     * the column specification. Element 0 is presumably the step name; it is not used here.
     */
    public List<Object[]> getFittedTransformers() {
        return getTupleList("transformers_");
    }

    /**
     * Resolves element 1 of a fitted-transformer tuple to a {@link Transformer}. The strings
     * {@code "drop"} and {@code "passthrough"} map to the built-in stand-ins; anything else must
     * already be a {@code Transformer}.
     *
     * @throws RuntimeException raised by {@link CastFunction} when the object is not a Transformer
     */
    private static Transformer getTransformer(Object[] fittedTransformer) {
        Object transformer = fittedTransformer[1];

        if ("drop".equals(transformer)) {
            return Drop.INSTANCE;
        }

        if ("passthrough".equals(transformer)) {
            return PassThrough.INSTANCE;
        }

        CastFunction<Transformer> castFunction = new CastFunction<Transformer>(Transformer.class) {

            @Override
            protected String formatMessage(Object object) {
                return "The estimator object (" + ClassDictUtil.formatClass(object) + ") is not a supported Transformer";
            }
        };

        return castFunction.apply(transformer);
    }

    /**
     * Resolves the column specification (element 2 of a fitted-transformer tuple) to features.
     *
     * <p>The specification may be an array-like ({@link HasArray}), a list, or a single string or
     * integer column. The result is materialised eagerly, so each column is resolved exactly
     * once and resolution errors surface immediately.
     *
     * @throws IllegalArgumentException if the specification or any of its columns is unsupported
     *     or cannot be resolved
     */
    private static List<Feature> getFeatures(Object[] fittedTransformer, List<Feature> features, SkLearnEncoder encoder) {
        Object columns = fittedTransformer[2];

        if (columns instanceof HasArray) {
            columns = ((HasArray) columns).getArrayContent();
        }

        List<?> columnList;
        if (columns instanceof List) {
            columnList = (List<?>) columns;
        } else if (columns instanceof String || columns instanceof Integer) {
            // A scalar column key selects a single column.
            columnList = Collections.singletonList(columns);
        } else {
            throw new IllegalArgumentException("The column specification object (" + ClassDictUtil.formatClass(columns) + ") is not a list, string or integer");
        }

        List<Feature> result = new ArrayList<>(columnList.size());
        for (Object column : columnList) {
            result.add(getFeature(column, features, encoder));
        }
        return result;
    }

    /**
     * Resolves a single column key to a feature.
     *
     * <p>A string key is matched against the feature names. An integer key is a positional
     * index; negative values count from the end, as in Python. When {@code features} is empty,
     * a string key names a wildcard data field directly, and a non-negative integer index
     * {@code i} names the field {@code "x" + (i + 1)}.
     *
     * @throws IllegalArgumentException if the key has an unsupported type, names an undefined
     *     column, or is out of range
     */
    private static Feature getFeature(Object column, List<Feature> features, SkLearnEncoder encoder) {

        if (column instanceof String) {
            String name = (String) column;

            if (features.isEmpty()) {
                return createWildcardFeature(FieldName.create(name), encoder);
            }

            for (Feature feature : features) {
                if (name.equals(feature.getName().getValue())) {
                    return feature;
                }
            }

            throw new IllegalArgumentException("Column '" + name + "' is undefined");
        }

        if (column instanceof Integer) {
            int index = (Integer) column;

            if (features.isEmpty()) {
                if (index < 0) {
                    throw new IllegalArgumentException("Column index " + index + " cannot be resolved without input features");
                }

                // Synthesised field names are 1-based.
                return createWildcardFeature(FieldName.create("x" + (index + 1)), encoder);
            }

            int size = features.size();
            int resolvedIndex = (index < 0) ? (index + size) : index;
            if (resolvedIndex < 0 || resolvedIndex >= size) {
                throw new IllegalArgumentException("Column index " + index + " is out of range for " + size + " input feature(s)");
            }

            return features.get(resolvedIndex);
        }

        throw new IllegalArgumentException("The column object (" + ClassDictUtil.formatClass(column) + ") is not a string or integer");
    }

    /**
     * Returns a wildcard feature for the named data field. The field is reused if the encoder
     * already has it, and created otherwise.
     */
    private static Feature createWildcardFeature(FieldName name, SkLearnEncoder encoder) {
        DataField dataField = encoder.getDataField(name);

        if (dataField == null) {
            dataField = encoder.createDataField(name);
        }

        return new WildcardFeature(encoder, dataField);
    }
}
