package org.tribuo.common.xgboost;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.common.xgboost.protos.XGBoostExternalModelProto;
import org.tribuo.common.xgboost.protos.XGBoostOutputConverterProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/xgboost/XGBoostExternalModel.class */
/**
 * An {@link ExternalModel} wrapping an XGBoost {@link Booster} that was trained outside of Tribuo.
 * <p>
 * Tribuo {@link SparseVector}s are converted into a {@link DMatrix} via {@link XGBoostTrainer},
 * the booster produces a {@code float[][]}, and the supplied {@link XGBoostOutputConverter} turns
 * that into Tribuo {@link Prediction}s.
 * <p>
 * NOTE(review): this source appears to be decompiler (JADX) output. The method names
 * {@code m1convertFeaturesList}, {@code m2convertFeatures}, {@code m3copy} and {@code m4serialize}
 * are decompiler renames of bridge-merged methods; they are kept unchanged here so the visible
 * interface is preserved.
 *
 * @param <T> The Tribuo output type.
 */
public final class XGBoostExternalModel<T extends Output<T>> extends ExternalModel<T, DMatrix, float[][]> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(XGBoostExternalModel.class.getName());

    /** The protobuf serialization version written by {@link #m4serialize()}. */
    public static final int CURRENT_VERSION = 0;

    /** Converts raw XGBoost prediction arrays into Tribuo predictions. */
    private final XGBoostOutputConverter<T> converter;

    /** The wrapped booster; transient because it is serialized manually as a byte array. */
    protected transient Booster model;

    /**
     * Builds a model from an external feature-name to feature-index mapping.
     */
    private XGBoostExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, Map<String, Integer> map, Booster booster, XGBoostOutputConverter<T> xGBoostOutputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, xGBoostOutputConverter.generatesProbabilities(), map);
        this.model = booster;
        this.converter = xGBoostOutputConverter;
    }

    /**
     * Builds a model from precomputed forward (Tribuo to external) and backward (external to Tribuo)
     * feature index mappings.
     */
    private XGBoostExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, int[] iArr, int[] iArr2, Booster booster, XGBoostOutputConverter<T> xGBoostOutputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, iArr, iArr2, xGBoostOutputConverter.generatesProbabilities());
        this.model = booster;
        this.converter = xGBoostOutputConverter;
    }

    /**
     * Deserializes a model from its protobuf form.
     *
     * @param i   The serialized version; only {@link #CURRENT_VERSION} is supported.
     * @param str The class name recorded in the proto (unused here).
     * @param any The packed {@code XGBoostExternalModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If {@code any} does not hold an {@code XGBoostExternalModelProto}.
     * @throws XGBoostError                   If XGBoost fails to load the booster bytes.
     * @throws IOException                    If reading the booster bytes fails.
     * @throws IllegalArgumentException       If the version is unsupported.
     * @throws IllegalStateException          If the output domain does not match the converter, or the
     *                                        feature mappings are not a bijection.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // converter and carrier types are only known at runtime
    public static XGBoostExternalModel<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException, XGBoostError, IOException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        XGBoostExternalModelProto proto = any.unpack(XGBoostExternalModelProto.class);
        XGBoostOutputConverter converter = (XGBoostOutputConverter) ProtoUtil.deserialize(proto.getConverter());
        Class<?> typeWitness = converter.getTypeWitness();
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // The first output in the domain must be of the type the converter produces.
        // Guard against a missing output so a malformed proto yields a descriptive error rather than an NPE.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null || !firstOutput.getClass().equals(typeWitness)) {
            String found = firstOutput == null ? "an empty output domain" : firstOutput.getClass().toString();
            throw new IllegalStateException("Invalid protobuf, output domain does not match the converter, found " + found + " and " + typeWitness);
        }

        int[] forwardMapping = Util.toPrimitiveInt(proto.getForwardFeatureMappingList());
        int[] backwardMapping = Util.toPrimitiveInt(proto.getBackwardFeatureMappingList());
        if (!validateFeatureMapping(forwardMapping, backwardMapping, carrier.featureDomain())) {
            throw new IllegalStateException("Invalid protobuf, external<->Tribuo feature mapping does not form a bijection");
        }
        Booster booster = XGBoost.loadModel(proto.getModel().toByteArray());
        return new XGBoostExternalModel<>(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), forwardMapping, backwardMapping, booster, converter);
    }

    /**
     * Converts a single feature vector into a {@link DMatrix}.
     * <p>
     * Decompiler rename of {@code convertFeatures} (merged with its bridge method); the original was protected.
     *
     * @param sparseVector The features, already in the external index space.
     * @return A single-row DMatrix.
     * @throws IllegalStateException If XGBoost fails to build the matrix.
     */
    public DMatrix m2convertFeatures(SparseVector sparseVector) {
        try {
            return XGBoostTrainer.convertSparseVector(sparseVector);
        } catch (XGBoostError e) {
            logger.severe("XGBoost threw an error while constructing the DMatrix.");
            throw new IllegalStateException(e);
        }
    }

    /**
     * Converts a batch of feature vectors into a single multi-row {@link DMatrix}.
     *
     * @param list The feature vectors.
     * @return A DMatrix with one row per vector.
     * @throws IllegalStateException If XGBoost fails to build the matrix.
     */
    protected DMatrix convertFeaturesList(List<SparseVector> list) {
        try {
            return XGBoostTrainer.convertSparseVectors(list);
        } catch (XGBoostError e) {
            logger.severe("XGBoost threw an error while constructing the DMatrix.");
            throw new IllegalStateException(e);
        }
    }

    /**
     * Runs the booster on the supplied matrix.
     *
     * @param dMatrix The input matrix.
     * @return The booster's raw prediction array.
     * @throws IllegalStateException If XGBoost fails to predict.
     */
    public float[][] externalPrediction(DMatrix dMatrix) {
        try {
            return this.model.predict(dMatrix);
        } catch (XGBoostError e) {
            logger.severe("XGBoost threw an error while predicting.");
            throw new IllegalStateException(e);
        }
    }

    /**
     * Converts the first row of the raw prediction array into a Tribuo prediction.
     */
    public Prediction<T> convertOutput(float[][] fArr, int i, Example<T> example) {
        return this.converter.convertOutput(this.outputIDInfo, Collections.singletonList(fArr[0]), i, example);
    }

    /**
     * Converts a batch of raw predictions into Tribuo predictions, one per example.
     */
    @SuppressWarnings("unchecked") // generic array creation for the converter's Example<T>[] parameter
    public List<Prediction<T>> convertOutput(float[][] fArr, int[] iArr, List<Example<T>> list) {
        return this.converter.convertBatchOutput(this.outputIDInfo, Collections.singletonList(fArr), iArr, (Example[]) list.toArray(new Example[0]));
    }

    /**
     * Returns the feature importance of the wrapped booster, as a singleton list.
     *
     * @return A single-element list of feature importance objects.
     */
    public List<XGBoostFeatureImportance> getFeatureImportance() {
        return Collections.singletonList(new XGBoostFeatureImportance(this.model, this));
    }

    /**
     * Returns the top {@code i} features ranked by the magnitude of their XGBoost feature score,
     * under the single key {@code "ALL_OUTPUTS"}.
     *
     * @param i The number of features to return; a negative value means all features.
     * @return A map from {@code "ALL_OUTPUTS"} to a descending list of (feature name, score) pairs,
     *         or an empty map if XGBoost throws while computing the scores.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        int maxFeatures = i < 0 ? this.featureIDMap.size() : i;
        if (maxFeatures == 0) {
            // PriorityQueue rejects an initial capacity of 0, and there is nothing to rank anyway.
            Map<String, List<Pair<String, Double>>> empty = new HashMap<>();
            empty.put("ALL_OUTPUTS", new ArrayList<>());
            return empty;
        }

        Map<String, Integer> featureScore;
        try {
            featureScore = this.model.getFeatureScore("");
        } catch (XGBoostError e) {
            logger.log(Level.SEVERE, "XGBoost threw an error", e);
            return Collections.emptyMap();
        }

        Comparator<Pair<String, Double>> byMagnitude = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        // Bounded min-heap: the head is the weakest of the features currently kept.
        PriorityQueue<Pair<String, Double>> queue = new PriorityQueue<>(maxFeatures, byMagnitude);
        for (Map.Entry<String, Integer> entry : featureScore.entrySet()) {
            // Keys are assumed to be "f<external index>"; strip the prefix and map back to the Tribuo feature id.
            int externalId = Integer.parseInt(entry.getKey().substring(1));
            String name = this.featureIDMap.get(this.featureBackwardMapping[externalId]).getName();
            Pair<String, Double> candidate = new Pair<>(name, entry.getValue().doubleValue());
            if (queue.size() < maxFeatures) {
                queue.offer(candidate);
            } else if (byMagnitude.compare(candidate, queue.peek()) > 0) {
                queue.poll();
                queue.offer(candidate);
            }
        }

        // Draining the min-heap yields ascending order; reverse it to get strongest first.
        List<Pair<String, Double>> ranked = new ArrayList<>(queue.size());
        while (!queue.isEmpty()) {
            ranked.add(queue.poll());
        }
        Collections.reverse(ranked);

        Map<String, List<Pair<String, Double>>> result = new HashMap<>();
        result.put("ALL_OUTPUTS", ranked);
        return result;
    }

    /**
     * Serializes this model into a {@link ModelProto}.
     * <p>
     * Decompiler rename of {@code serialize} (merged with its bridge method).
     *
     * @return The protobuf representation of this model.
     * @throws IllegalStateException If XGBoost fails to export the booster bytes.
     */
    public ModelProto m4serialize() {
        ModelDataCarrier<T> carrier = createDataCarrier();
        XGBoostExternalModelProto.Builder modelBuilder = XGBoostExternalModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setConverter((XGBoostOutputConverterProto) this.converter.serialize());
        modelBuilder.addAllForwardFeatureMapping(Arrays.stream(this.featureForwardMapping).boxed().collect(Collectors.toList()));
        modelBuilder.addAllBackwardFeatureMapping(Arrays.stream(this.featureBackwardMapping).boxed().collect(Collectors.toList()));

        byte[] boosterBytes;
        try {
            boosterBytes = this.model.toByteArray();
        } catch (XGBoostError e) {
            // Keep the XGBoost error as the cause so the native failure is not lost.
            throw new IllegalStateException("Failed to serialize XGBoost model", e);
        }
        modelBuilder.setModel(ByteString.copyFrom(boosterBytes));

        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m52build()));
        protoBuilder.setClassName(XGBoostExternalModel.class.getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.build();
    }

    /**
     * Creates a copy of this model with a new name and provenance, and a copy of the booster.
     * <p>
     * Decompiler rename of {@code copy} (merged with its bridge method); the original was protected.
     */
    public XGBoostExternalModel<T> m3copy(String str, ModelProvenance modelProvenance) {
        return new XGBoostExternalModel<>(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.featureForwardMapping, this.featureBackwardMapping, XGBoostModel.copyModel(this.model), this.converter);
    }

    /**
     * Loads an XGBoost model from a file path and wraps it.
     *
     * @param outputFactory          The output factory.
     * @param map                    Feature name to external feature index mapping.
     * @param map2                   Output to external output index mapping.
     * @param xGBoostOutputConverter The converter for the booster's predictions.
     * @param str                    The path to the serialized booster.
     * @return The wrapped model.
     * @throws IllegalArgumentException If the model cannot be loaded.
     */
    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, XGBoostOutputConverter<T> xGBoostOutputConverter, String str) {
        try {
            Booster booster = XGBoost.loadModel(str);
            ExternalTrainerProvenance provenance = new ExternalTrainerProvenance(new File(str).toURI().toURL());
            return createXGBoostModel(outputFactory, map, map2, xGBoostOutputConverter, booster, provenance, Collections.emptyMap());
        } catch (XGBoostError | MalformedURLException e) {
            throw new IllegalArgumentException("Unable to load model from path " + str, e);
        }
    }

    /**
     * Loads an XGBoost model from a {@link Path} and wraps it.
     * <p>
     * The file is read fully into memory and the booster is loaded from those bytes, so no stream is
     * left open.
     *
     * @param outputFactory          The output factory.
     * @param map                    Feature name to external feature index mapping.
     * @param map2                   Output to external output index mapping.
     * @param xGBoostOutputConverter The converter for the booster's predictions.
     * @param path                   The path to the serialized booster.
     * @return The wrapped model.
     * @throws IllegalArgumentException If the file cannot be read or the model cannot be loaded.
     */
    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, XGBoostOutputConverter<T> xGBoostOutputConverter, Path path) {
        try {
            Booster booster = XGBoost.loadModel(Files.readAllBytes(path));
            ExternalTrainerProvenance provenance = new ExternalTrainerProvenance(path.toUri().toURL());
            return createXGBoostModel(outputFactory, map, map2, xGBoostOutputConverter, booster, provenance, Collections.emptyMap());
        } catch (XGBoostError | IOException e) {
            throw new IllegalArgumentException("Unable to load model from path " + path, e);
        }
    }

    /**
     * Wraps an in-memory booster, recording {@code url} as its trainer provenance.
     *
     * @deprecated Use the overload taking a {@code Map<String, Provenance>}, which records the booster
     *             bytes as provenance.
     */
    @Deprecated
    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, XGBoostOutputConverter<T> xGBoostOutputConverter, Booster booster, URL url) {
        return createXGBoostModel(outputFactory, map, map2, xGBoostOutputConverter, booster, new ExternalTrainerProvenance(url), Collections.emptyMap());
    }

    /**
     * Wraps an in-memory booster, recording its serialized bytes as trainer provenance.
     *
     * @param outputFactory          The output factory.
     * @param map                    Feature name to external feature index mapping.
     * @param map2                   Output to external output index mapping.
     * @param xGBoostOutputConverter The converter for the booster's predictions.
     * @param booster                The booster to wrap.
     * @param map3                   Extra provenance entries stored on the model provenance.
     * @return The wrapped model.
     * @throws IllegalStateException If the booster bytes cannot be extracted.
     */
    public static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, XGBoostOutputConverter<T> xGBoostOutputConverter, Booster booster, Map<String, Provenance> map3) {
        try {
            return createXGBoostModel(outputFactory, map, map2, xGBoostOutputConverter, booster, new ExternalTrainerProvenance(booster.toByteArray()), map3);
        } catch (XGBoostError e) {
            throw new IllegalStateException("Unable to extract byte array from booster", e);
        }
    }

    /**
     * Shared factory: builds the model provenance, feature map and output info, then constructs the model.
     */
    private static <T extends Output<T>> XGBoostExternalModel<T> createXGBoostModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, XGBoostOutputConverter<T> xGBoostOutputConverter, Booster booster, ExternalTrainerProvenance externalTrainerProvenance, Map<String, Provenance> map3) {
        ExternalDatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", outputFactory, false, map.size(), map2.size());
        ModelProvenance provenance = new ModelProvenance(XGBoostExternalModel.class.getName(), OffsetDateTime.now(), datasetProvenance, externalTrainerProvenance, map3);
        return new XGBoostExternalModel<>("external-model", provenance, ExternalModel.createFeatureMap(map.keySet()), ExternalModel.createOutputInfo(outputFactory, map2), map, booster, xGBoostOutputConverter);
    }

    /**
     * Java serialization hook: writes the default fields followed by the booster as a byte array.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        try {
            objectOutputStream.writeObject(this.model.toByteArray());
        } catch (XGBoostError e) {
            throw new IOException("Failed to serialize the XGBoost model", e);
        }
    }

    /**
     * Java serialization hook: reads the default fields, then rebuilds the transient booster from its bytes.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        try {
            this.model = XGBoost.loadModel(new ByteArrayInputStream((byte[]) objectInputStream.readObject()));
        } catch (XGBoostError e) {
            throw new IOException("Failed to deserialize the XGBoost model", e);
        }
    }

    /**
     * Raw-typed bridge for {@link #convertFeaturesList(List)} (decompiler-renamed synthetic method).
     */
    @SuppressWarnings("unchecked")
    protected /* bridge */ /* synthetic */ Object m1convertFeaturesList(List list) {
        return convertFeaturesList((List<SparseVector>) list);
    }
}
