package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxModelMetadata;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerializationException;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/interop/onnx/ONNXExternalModel.class */
public final class ONNXExternalModel<T extends Output<T>> extends ExternalModel<T, OnnxTensor, List<OnnxValue>> implements AutoCloseable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(ONNXExternalModel.class.getName());
    private transient OrtEnvironment env;
    private transient OrtSession.SessionOptions options;
    private transient OrtSession session;
    private final byte[] modelArray;
    private final String inputName;
    private final ExampleTransformer featureTransformer;
    private final OutputTransformer<T> outputTransformer;

    private ONNXExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, Map<String, Integer> map, byte[] bArr, OrtSession.SessionOptions sessionOptions, String str2, ExampleTransformer exampleTransformer, OutputTransformer<T> outputTransformer) throws OrtException {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, outputTransformer.generatesProbabilities(), map);
        this.modelArray = bArr;
        this.options = sessionOptions;
        this.inputName = str2;
        this.featureTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.env = OrtEnvironment.getEnvironment("tribuo-" + str);
        this.session = this.env.createSession(bArr, sessionOptions);
    }

    private ONNXExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, int[] iArr, int[] iArr2, byte[] bArr, OrtSession.SessionOptions sessionOptions, String str2, ExampleTransformer exampleTransformer, OutputTransformer<T> outputTransformer) throws OrtException {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, iArr, iArr2, outputTransformer.generatesProbabilities());
        this.modelArray = bArr;
        this.options = sessionOptions;
        this.inputName = str2;
        this.featureTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.env = OrtEnvironment.getEnvironment("tribuo-" + str);
        this.session = this.env.createSession(bArr, sessionOptions);
    }

    public synchronized void rebuild(OrtSession.SessionOptions sessionOptions) throws OrtException {
        this.session.close();
        if (this.options != null) {
            this.options.close();
        }
        this.options = sessionOptions;
        this.env.createSession(this.modelArray, sessionOptions);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: convertFeatures, reason: merged with bridge method [inline-methods] */
    public OnnxTensor m10convertFeatures(SparseVector sparseVector) {
        try {
            return this.featureTransformer.transform(this.env, sparseVector);
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to construct input OnnxTensor", e);
        }
    }

    protected OnnxTensor convertFeaturesList(List<SparseVector> list) {
        try {
            return this.featureTransformer.transform(this.env, list);
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to construct input OnnxTensor", e);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public List<OnnxValue> externalPrediction(OnnxTensor onnxTensor) {
        try {
            OrtSession.Result run = this.session.run(Collections.singletonMap(this.inputName, onnxTensor));
            onnxTensor.close();
            ArrayList arrayList = new ArrayList();
            Iterator it = run.iterator();
            while (it.hasNext()) {
                arrayList.add(((Map.Entry) it.next()).getValue());
            }
            return arrayList;
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to execute ONNX model", e);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public Prediction<T> convertOutput(List<OnnxValue> list, int i, Example<T> example) {
        Prediction<T> transformToPrediction = this.outputTransformer.transformToPrediction(list, this.outputIDInfo, i, example);
        OnnxValue.close(list);
        return transformToPrediction;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public List<Prediction<T>> convertOutput(List<OnnxValue> list, int[] iArr, List<Example<T>> list2) {
        List<Prediction<T>> transformToBatchPrediction = this.outputTransformer.transformToBatchPrediction(list, this.outputIDInfo, iArr, list2);
        OnnxValue.close(list);
        return transformToBatchPrediction;
    }

    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    protected synchronized Model<T> copy(String str, ModelProvenance modelProvenance) {
        try {
            return new ONNXExternalModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.featureForwardMapping, this.featureBackwardMapping, Arrays.copyOf(this.modelArray, this.modelArray.length), this.options, this.inputName, this.featureTransformer, this.outputTransformer);
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to copy ONNX model", e);
        }
    }

    @Override // java.lang.AutoCloseable
    public void close() {
        if (this.session != null) {
            try {
                this.session.close();
            } catch (OrtException e) {
                logger.log(Level.SEVERE, "Exception thrown when closing session", e);
            }
        }
        if (this.options != null) {
            this.options.close();
        }
        if (this.env != null) {
            try {
                this.env.close();
            } catch (OrtException e2) {
                logger.log(Level.SEVERE, "Exception thrown when closing environment", e2);
            }
        }
    }

    public Optional<ModelProvenance> getTribuoProvenance() {
        try {
            Optional customMetadataValue = this.session.getMetadata().getCustomMetadataValue("TRIBUO_PROVENANCE");
            if (!customMetadataValue.isPresent()) {
                return Optional.empty();
            }
            ModelProvenance deserializeAndUnmarshal = ONNXExportable.SERIALIZER.deserializeAndUnmarshal((String) customMetadataValue.get());
            if (deserializeAndUnmarshal instanceof ModelProvenance) {
                return Optional.of(deserializeAndUnmarshal);
            }
            logger.log(Level.WARNING, "Found invalid provenance object, " + deserializeAndUnmarshal.toString());
            return Optional.empty();
        } catch (OrtException e) {
            logger.log(Level.WARNING, "ORTException when reading session metadata", e);
            return Optional.empty();
        } catch (ProvenanceSerializationException e2) {
            logger.log(Level.WARNING, "Failed to parse provenance from value.", e2);
            return Optional.empty();
        }
    }

    public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, ExampleTransformer exampleTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions sessionOptions, String str, String str2) throws OrtException {
        return createOnnxModel(outputFactory, map, map2, exampleTransformer, outputTransformer, sessionOptions, Paths.get(str, new String[0]), str2);
    }

    /* JADX WARN: Failed to calculate best type for var: r29v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r29v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r30v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r30v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 29, insn: 0x01f0: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r29 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:67:0x01f0 */
    /* JADX WARN: Not initialized variable reg: 30, insn: 0x01f5: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r30 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:69:0x01f5 */
    /* JADX WARN: Type inference failed for: r29v0, types: [ai.onnxruntime.OrtEnvironment] */
    /* JADX WARN: Type inference failed for: r30v0, types: [java.lang.Throwable] */
    public static <T extends Output<T>> ONNXExternalModel<T> createOnnxModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, ExampleTransformer exampleTransformer, OutputTransformer<T> outputTransformer, OrtSession.SessionOptions sessionOptions, Path path, String str) throws OrtException {
        ?? r29;
        ?? r30;
        try {
            byte[] readAllBytes = Files.readAllBytes(path);
            URL url = path.toUri().toURL();
            ImmutableFeatureMap createFeatureMap = ExternalModel.createFeatureMap(map.keySet());
            ImmutableOutputInfo createOutputInfo = ExternalModel.createOutputInfo(outputFactory, map2);
            OffsetDateTime now = OffsetDateTime.now();
            ExternalTrainerProvenance externalTrainerProvenance = new ExternalTrainerProvenance(url);
            ExternalDatasetProvenance externalDatasetProvenance = new ExternalDatasetProvenance("unknown-external-data", outputFactory, false, map.size(), map2.size());
            HashMap hashMap = new HashMap();
            hashMap.put("input-name", new StringProvenance("input-name", str));
            try {
                try {
                    OrtEnvironment environment = OrtEnvironment.getEnvironment();
                    Throwable th = null;
                    OrtSession createSession = environment.createSession(readAllBytes);
                    Throwable th2 = null;
                    try {
                        try {
                            OnnxModelMetadata metadata = createSession.getMetadata();
                            hashMap.put("model-producer", new StringProvenance("model-producer", metadata.getProducerName()));
                            hashMap.put("model-domain", new StringProvenance("model-domain", metadata.getDomain()));
                            hashMap.put("model-description", new StringProvenance("model-description", metadata.getDescription()));
                            hashMap.put("model-graphname", new StringProvenance("model-graphname", metadata.getGraphName()));
                            hashMap.put("model-version", new LongProvenance("model-version", metadata.getVersion()));
                            for (Map.Entry entry : metadata.getCustomMetadata().entrySet()) {
                                if (!((String) entry.getKey()).equals("TRIBUO_PROVENANCE")) {
                                    String str2 = "model-metadata-" + ((String) entry.getKey());
                                    hashMap.put(str2, new StringProvenance(str2, (String) entry.getValue()));
                                }
                            }
                            if (createSession != null) {
                                if (0 != 0) {
                                    try {
                                        createSession.close();
                                    } catch (Throwable th3) {
                                        th2.addSuppressed(th3);
                                    }
                                } else {
                                    createSession.close();
                                }
                            }
                            if (environment != null) {
                                if (0 != 0) {
                                    try {
                                        environment.close();
                                    } catch (Throwable th4) {
                                        th.addSuppressed(th4);
                                    }
                                } else {
                                    environment.close();
                                }
                            }
                            return new ONNXExternalModel<>("external-model", new ModelProvenance(ONNXExternalModel.class.getName(), now, externalDatasetProvenance, externalTrainerProvenance, hashMap), createFeatureMap, createOutputInfo, map, readAllBytes, sessionOptions, str, exampleTransformer, outputTransformer);
                        } finally {
                        }
                    } catch (Throwable th5) {
                        if (createSession != null) {
                            if (th2 != null) {
                                try {
                                    createSession.close();
                                } catch (Throwable th6) {
                                    th2.addSuppressed(th6);
                                }
                            } else {
                                createSession.close();
                            }
                        }
                        throw th5;
                    }
                } catch (OrtException e) {
                    throw new IllegalArgumentException("Failed to load model and read metadata from path " + path, e);
                }
            } catch (Throwable th7) {
                if (r29 != 0) {
                    if (r30 != 0) {
                        try {
                            r29.close();
                        } catch (Throwable th8) {
                            r30.addSuppressed(th8);
                        }
                    } else {
                        r29.close();
                    }
                }
                throw th7;
            }
        } catch (IOException e2) {
            throw new IllegalArgumentException("Unable to load model from path " + path, e2);
        }
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        try {
            this.env = OrtEnvironment.getEnvironment();
            this.options = new OrtSession.SessionOptions();
            this.session = this.env.createSession(this.modelArray, this.options);
        } catch (OrtException e) {
            throw new IllegalStateException("Could not construct ONNX Runtime session during deserialization.");
        }
    }

    /* renamed from: convertFeaturesList, reason: collision with other method in class */
    protected /* bridge */ /* synthetic */ Object m9convertFeaturesList(List list) {
        return convertFeaturesList((List<SparseVector>) list);
    }
}
