package org.tribuo.common.xgboost;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.common.xgboost.protos.XGBoostModelProto;
import org.tribuo.common.xgboost.protos.XGBoostOutputConverterProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/common/xgboost/XGBoostModel.class */
public final class XGBoostModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 4L;
    private static final Logger logger = Logger.getLogger(XGBoostModel.class.getName());

    /** The highest protobuf serialization version accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Turns raw XGBoost prediction arrays into Tribuo {@link Prediction}s. */
    private final XGBoostOutputConverter<T> converter;

    /**
     * True once {@link #models} is known to be in output-id order for regression outputs.
     * The constructor sets it to true. Java serialization leaves it false for streams that did
     * not record it; {@link #readObject} then may reorder the boosters (see
     * {@code applyRegression41MappingFix}).
     */
    private boolean regression41MappingFix;

    /**
     * The wrapped boosters. Either a single booster covering all outputs, or (judging by
     * {@link #getTopFeatures}) one booster per output id. Transient because boosters are
     * written as raw byte arrays by {@link #writeObject}.
     */
    protected transient List<Booster> models;

    /**
     * Creates a model wrapping the supplied XGBoost boosters.
     *
     * @param name         The model name.
     * @param provenance   The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models       The boosters; the list is stored as-is, not copied.
     * @param converter    The converter from XGBoost outputs to Tribuo predictions.
     */
    public XGBoostModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Booster> models, XGBoostOutputConverter<T> converter) {
        super(name, provenance, featureIDMap, outputIDInfo, converter.generatesProbabilities());
        this.converter = converter;
        this.models = models;
        this.regression41MappingFix = true;
    }

    /**
     * Deserialization factory used by the protobuf loader.
     *
     * @param version   The serialized object version; must be in {@code [0, CURRENT_VERSION]}.
     * @param className The class name that the protobuf claims (unused here).
     * @param message   The serialized data, expected to wrap an {@code XGBoostModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message is not an {@code XGBoostModelProto}.
     * @throws XGBoostError If XGBoost fails to load a booster.
     * @throws IOException If XGBoost fails to read a booster's bytes.
     * @throws IllegalArgumentException If the version is unsupported.
     * @throws IllegalStateException If the protobuf is internally inconsistent or contains no boosters.
     */
    public static XGBoostModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException, XGBoostError, IOException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        XGBoostModelProto proto = message.unpack(XGBoostModelProto.class);

        XGBoostOutputConverter<?> converter = (XGBoostOutputConverter<?>) ProtoUtil.deserialize(proto.getConverter());
        Class<?> converterWitness = converter.getTypeWitness();
        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // The converter and the output domain must agree on the Output subtype, otherwise the
        // unchecked construction below would produce a mis-typed model.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null) {
            throw new IllegalStateException("Invalid protobuf, output domain is empty");
        }
        if (!firstOutput.getClass().equals(converterWitness)) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match the converter, found " + firstOutput.getClass() + " and " + converterWitness);
        }

        List<Booster> boosters = new ArrayList<>();
        for (ByteString bytes : proto.getModelsList()) {
            boosters.add(XGBoost.loadModel(bytes.toByteArray()));
        }
        if (boosters.isEmpty()) {
            throw new IllegalStateException("Invalid protobuf, no XGBoost models were found");
        }

        @SuppressWarnings({"unchecked", "rawtypes"}) // output type checked against the converter witness above
        XGBoostModel<?> model = new XGBoostModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), boosters, converter);
        return model;
    }

    /**
     * Returns deep copies of the wrapped boosters, so callers cannot mutate this model.
     *
     * @return An unmodifiable list of booster copies.
     */
    public List<Booster> getInnerModels() {
        List<Booster> copies = new ArrayList<>(models.size());
        for (Booster booster : models) {
            copies.add(copyModel(booster));
        }
        return Collections.unmodifiableList(copies);
    }

    /**
     * Sets XGBoost's {@code nthread} parameter on every booster. Negative values are ignored.
     *
     * @param numThreads The number of threads to use at prediction time.
     * @throws IllegalStateException If XGBoost rejects the parameter.
     */
    public void setNumThreads(int numThreads) {
        if (numThreads > -1) {
            try {
                for (Booster booster : models) {
                    booster.setParam("nthread", Integer.valueOf(numThreads));
                }
            } catch (XGBoostError e) {
                logger.log(Level.SEVERE, "XGBoost threw an error", e);
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Predicts every example in the dataset.
     *
     * @param examples The dataset to predict.
     * @return One prediction per example.
     */
    @Override
    public List<Prediction<T>> predict(Dataset<T> examples) {
        return predict(examples.getData());
    }

    /**
     * Predicts a batch of examples by running each booster over a single DMatrix.
     *
     * @param examples The examples to predict.
     * @return The predictions produced by the output converter.
     * @throws IllegalStateException If XGBoost throws during conversion or prediction.
     */
    @Override
    public List<Prediction<T>> predict(Iterable<Example<T>> examples) {
        try {
            XGBoostTrainer.DMatrixTuple matrix = XGBoostTrainer.convertExamples(examples, featureIDMap);
            List<float[][]> outputs = new ArrayList<>(models.size());
            for (Booster booster : models) {
                outputs.add(booster.predict(matrix.data));
            }
            return converter.convertBatchOutput(outputIDInfo, outputs, matrix.numValidFeatures, matrix.examples);
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Predicts a single example.
     *
     * @param example The example to predict.
     * @return The prediction produced by the output converter.
     * @throws IllegalStateException If XGBoost throws during conversion or prediction.
     */
    @Override
    public Prediction<T> predict(Example<T> example) {
        try {
            XGBoostTrainer.DMatrixTuple matrix = XGBoostTrainer.convertExample(example, featureIDMap);
            List<float[]> outputs = new ArrayList<>(models.size());
            for (Booster booster : models) {
                // Single-row matrix, so take the first (only) row of each booster's output.
                outputs.add(booster.predict(matrix.data)[0]);
            }
            return converter.convertOutput(outputIDInfo, outputs, matrix.numValidFeatures[0], example);
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Creates one {@link XGBoostFeatureImportance} per wrapped booster.
     *
     * @return The feature importance objects, in booster order.
     */
    public List<XGBoostFeatureImportance> getFeatureImportance() {
        return models.stream()
                .map(booster -> new XGBoostFeatureImportance(booster, this))
                .collect(Collectors.toList());
    }

    /**
     * Returns the top {@code n} features per booster, ranked by the magnitude of XGBoost's
     * feature score summed per Tribuo feature name.
     *
     * @param n The number of features to return; negative means all features, zero yields empty lists.
     * @return A map from output name (or {@code "ALL_OUTPUTS"} for a single booster) to a
     *         descending list of (feature name, score) pairs; an empty map if XGBoost throws.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        Comparator<Pair<String, Double>> byMagnitude = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        try {
            for (int i = 0; i < models.size(); i++) {
                Map<String, MutableDouble> scores = aggregateFeatureScores(models.get(i));
                List<Pair<String, Double>> ranked = selectTop(scores, maxFeatures, byMagnitude);
                String key = models.size() == 1 ? "ALL_OUTPUTS" : outputIDInfo.getOutput(i).toString();
                topFeatures.put(key, ranked);
            }
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            return Collections.emptyMap();
        }
        return topFeatures;
    }

    /** Sums one booster's integer XGBoost feature scores, keyed by Tribuo feature name. */
    private Map<String, MutableDouble> aggregateFeatureScores(Booster booster) throws XGBoostError {
        Map<String, MutableDouble> scores = new HashMap<>();
        for (Map.Entry<String, Integer> entry : booster.getFeatureScore("").entrySet()) {
            // XGBoost prefixes feature names with one character (presumably "f"); the rest is the Tribuo id.
            int id = Integer.parseInt(entry.getKey().substring(1));
            String name = featureIDMap.get(id).getName();
            scores.computeIfAbsent(name, k -> new MutableDouble()).increment(entry.getValue().intValue());
        }
        return scores;
    }

    /** Returns the {@code k} largest entries under {@code comparator}, in descending order. */
    private static List<Pair<String, Double>> selectTop(Map<String, MutableDouble> scores, int k, Comparator<Pair<String, Double>> comparator) {
        List<Pair<String, Double>> result = new ArrayList<>();
        if (k <= 0) {
            // PriorityQueue rejects capacity 0, and an empty heap cannot be compared against.
            return result;
        }
        // Min-heap bounded at k entries: its head is the weakest of the current top k.
        PriorityQueue<Pair<String, Double>> heap = new PriorityQueue<>(Math.max(1, Math.min(k, scores.size())), comparator);
        for (Map.Entry<String, MutableDouble> entry : scores.entrySet()) {
            Pair<String, Double> candidate = new Pair<>(entry.getKey(), entry.getValue().doubleValue());
            if (heap.size() < k) {
                heap.offer(candidate);
            } else if (comparator.compare(candidate, heap.peek()) > 0) {
                heap.poll();
                heap.offer(candidate);
            }
        }
        while (!heap.isEmpty()) {
            result.add(heap.poll());
        }
        Collections.reverse(result);
        return result;
    }

    /**
     * Dumps every booster's trees as text, including statistics.
     *
     * @return One string array per booster.
     * @throws IllegalStateException If XGBoost fails to produce a dump.
     */
    public List<String[]> getModelDump() {
        try {
            List<String[]> dumps = new ArrayList<>(models.size());
            for (Booster booster : models) {
                dumps.add(booster.getModelDump("", true));
            }
            return dumps;
        } catch (XGBoostError e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * XGBoost models do not produce excuses.
     *
     * @param example The example (ignored).
     * @return Always {@link Optional#empty()}.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Deep copies a booster by round-tripping it through its byte representation.
     *
     * @param booster The booster to copy.
     * @return An independent copy.
     * @throws IllegalStateException If XGBoost fails to serialize or reload the booster.
     */
    public static Booster copyModel(Booster booster) {
        try {
            return XGBoost.loadModel(booster.toByteArray());
        } catch (XGBoostError | IOException e) {
            throw new IllegalStateException("Unable to copy XGBoost model.", e);
        }
    }

    /**
     * Copies this model with deep-copied boosters and a new name and provenance.
     *
     * @param newName       The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copied model.
     */
    @Override
    protected Model<T> copy(String newName, ModelProvenance newProvenance) {
        List<Booster> copies = new ArrayList<>(models.size());
        for (Booster booster : models) {
            copies.add(copyModel(booster));
        }
        return new XGBoostModel<>(newName, newProvenance, featureIDMap, outputIDInfo, copies, converter);
    }

    /**
     * Serializes this model to a {@link ModelProto} wrapping an {@code XGBoostModelProto}.
     * (The decompiler renamed this method from {@code serialize}.)
     *
     * @return The protobuf representation.
     * @throws IllegalStateException If XGBoost fails to serialize a booster.
     */
    public ModelProto m6serialize() {
        ModelDataCarrier<T> carrier = createDataCarrier();
        XGBoostModelProto.Builder modelBuilder = XGBoostModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setConverter((XGBoostOutputConverterProto) converter.serialize());
        try {
            for (Booster booster : models) {
                modelBuilder.addModels(ByteString.copyFrom(booster.toByteArray()));
            }
        } catch (XGBoostError e) {
            throw new IllegalStateException("Failed to serialize XGBoost model", e);
        }
        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m99build()));
        protoBuilder.setClassName(XGBoostModel.class.getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.build();
    }

    /**
     * Writes the default fields, then the booster count followed by each booster's bytes.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        try {
            out.writeInt(models.size());
            for (Booster booster : models) {
                out.writeObject(booster.toByteArray());
            }
        } catch (XGBoostError e) {
            throw new IOException("Failed to serialize the XGBoost model", e);
        }
    }

    /**
     * Reads the default fields and the boosters written by {@link #writeObject}, then applies
     * the regression booster-order fix for models written by affected Tribuo versions.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        try {
            this.models = new ArrayList<>();
            int numModels = in.readInt();
            for (int i = 0; i < numModels; i++) {
                this.models.add(XGBoost.loadModel((byte[]) in.readObject()));
            }
        } catch (XGBoostError e) {
            throw new IOException("Failed to deserialize the XGBoost model", e);
        }
        applyRegression41MappingFix();
    }

    /**
     * Reorders the boosters of regression models written by affected Tribuo versions using
     * {@code ImmutableRegressionInfo.getIDtoNaturalOrderMapping}, reached reflectively so this
     * module does not depend on the regression module.
     */
    private void applyRegression41MappingFix() {
        Class<?> regressionInfoClass;
        try {
            regressionInfoClass = Class.forName("org.tribuo.regression.ImmutableRegressionInfo");
        } catch (ClassNotFoundException ignored) {
            // Regression module absent from the classpath, so this cannot be a regression model.
            return;
        }
        if (!regressionInfoClass.isInstance(outputIDInfo) || regression41MappingFix) {
            return;
        }
        String version = (String) ((PrimitiveProvenance<?>) provenance.getTrainerProvenance().getInstanceValues().get("tribuo-version")).getValue();
        if (!isAffectedVersion(version)) {
            return;
        }
        regression41MappingFix = true;
        try {
            int[] mapping = (int[]) regressionInfoClass.getDeclaredMethod("getIDtoNaturalOrderMapping").invoke(outputIDInfo);
            List<Booster> reordered = new ArrayList<>(models);
            for (int i = 0; i < mapping.length; i++) {
                reordered.set(i, models.get(mapping[i]));
            }
            this.models = reordered;
        } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
            throw new RuntimeException("Failed to rewrite 4.1.0 or earlier regression model due to a reflection failure.", e);
        }
    }

    /** Returns true for the Tribuo versions whose regression boosters need reordering. */
    private static boolean isAffectedVersion(String version) {
        return version.startsWith("4.0.0") || version.startsWith("4.0.1") || version.startsWith("4.0.2")
                || version.startsWith("4.1.0") || version.equals("4.1.1-SNAPSHOT");
    }
}
