package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxSequence;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.SequenceInfo;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;

/* loaded from: input_file:org/tribuo/interop/onnx/LabelTransformer.class */
/**
 * Converts the outputs of an ONNX classification model into Tribuo {@link Label}s and
 * {@link Prediction}s.
 * <p>
 * Two output layouts are accepted:
 * <ul>
 *     <li>a single float {@link OnnxTensor} of shape {@code [batchSize, numOutputs]}, or</li>
 *     <li>two {@link OnnxValue}s where the second is an {@link OnnxSequence} of
 *     {@code Map<Long,Float>} (one map per example, keyed by output id). The first value is
 *     not read.</li>
 * </ul>
 * The per-output scores are reported as probabilities (see {@link #generatesProbabilities()}).
 */
public class LabelTransformer implements OutputTransformer<Label> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());

    /**
     * Converts a single-example ONNX result into a {@link Prediction}.
     *
     * @param tensor The ONNX model outputs.
     * @param outputIDInfo The output domain, mapping ids to labels.
     * @param numValidFeatures The number of features used to produce the prediction.
     * @param example The example that was predicted.
     * @return A prediction containing every label score and the highest scoring label.
     * @throws IllegalArgumentException If the outputs are malformed or do not contain exactly one result.
     * @throws IllegalStateException If ONNX Runtime fails to read the outputs.
     */
    @Override
    public Prediction<Label> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        requireSingleResult(predictions);
        return generatePrediction(predictions[0], outputIDInfo, numValidFeatures, example);
    }

    /**
     * Converts a single-example ONNX result into the highest scoring {@link Label}.
     *
     * @param tensor The ONNX model outputs.
     * @param outputIDInfo The output domain, mapping ids to labels.
     * @return The highest scoring label, carrying its score.
     * @throws IllegalArgumentException If the outputs are malformed or do not contain exactly one result.
     * @throws IllegalStateException If ONNX Runtime fails to read the outputs.
     */
    @Override
    public Label transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        requireSingleResult(predictions);
        return generateLabel(predictions[0], outputIDInfo);
    }

    /**
     * Decompiler-generated alias of {@link #transformToOutput(List, ImmutableOutputInfo)},
     * retained only so existing callers keep compiling.
     *
     * @param list The ONNX model outputs.
     * @param immutableOutputInfo The output domain.
     * @return The highest scoring label.
     * @deprecated Use {@link #transformToOutput(List, ImmutableOutputInfo)}.
     */
    @Deprecated
    public Label transformToOutput2(List<OnnxValue> list, ImmutableOutputInfo<Label> immutableOutputInfo) {
        return transformToOutput(list, immutableOutputInfo);
    }

    /**
     * Converts a batched ONNX result into one {@link Prediction} per example.
     *
     * @param tensor The ONNX model outputs.
     * @param outputIDInfo The output domain, mapping ids to labels.
     * @param numValidFeatures The number of features used for each example.
     * @param examples The examples that were predicted.
     * @return The predictions, in example order.
     * @throws IllegalArgumentException If the outputs are malformed, or the number of results does not
     *                                  match both {@code examples.size()} and {@code numValidFeatures.length}.
     * @throws IllegalStateException If ONNX Runtime fails to read the outputs.
     */
    @Override
    public List<Prediction<Label>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo, int[] numValidFeatures, List<Example<Label>> examples) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        if (predictions.length != examples.size() || predictions.length != numValidFeatures.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from the ONNXExternalModel, expected " + numValidFeatures.length + " (with " + examples.size() + " examples), received " + predictions.length);
        }
        List<Prediction<Label>> output = new ArrayList<>(predictions.length);
        for (int i = 0; i < predictions.length; i++) {
            output.add(generatePrediction(predictions[i], outputIDInfo, numValidFeatures[i], examples.get(i)));
        }
        return output;
    }

    /**
     * Converts a batched ONNX result into the highest scoring {@link Label} for each example.
     *
     * @param tensor The ONNX model outputs.
     * @param outputIDInfo The output domain, mapping ids to labels.
     * @return One label per result, in result order.
     * @throws IllegalArgumentException If the outputs are malformed.
     * @throws IllegalStateException If ONNX Runtime fails to read the outputs.
     */
    @Override
    public List<Label> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<Label> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo);
        List<Label> output = new ArrayList<>(predictions.length);
        for (float[] scores : predictions) {
            output.add(generateLabel(scores, outputIDInfo));
        }
        return output;
    }

    @Override
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "LabelTransformer()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, retained only so existing callers keep compiling.
     *
     * @return The provenance of this transformer.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m4getProvenance() {
        return getProvenance();
    }

    /**
     * Checks that a batch holds exactly one result. Shared by the single-example entry points.
     */
    private static void requireSingleResult(float[][] predictions) {
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Expected exactly one result in the supplied tensor, predictions.length = " + predictions.length);
        }
    }

    /**
     * Builds a {@link Prediction} holding a label for every output id plus the highest scoring label.
     */
    private Prediction<Label> generatePrediction(float[] scores, ImmutableOutputInfo<Label> outputIDInfo, int numValidFeatures, Example<Label> example) {
        Label[] labels = new Label[scores.length];
        Map<String, Label> fullLabels = new HashMap<>();
        for (int i = 0; i < scores.length; i++) {
            labels[i] = new Label(outputIDInfo.getOutput(i).getLabel(), scores[i]);
            fullLabels.put(labels[i].getLabel(), labels[i]);
        }
        // Same selection rule as generateLabel, so both entry points agree on the winning label.
        return new Prediction<>(labels[argmax(scores)], fullLabels, numValidFeatures, example, true);
    }

    /**
     * Returns the highest scoring label, carrying its own score.
     */
    private Label generateLabel(float[] scores, ImmutableOutputInfo<Label> outputIDInfo) {
        int best = argmax(scores);
        return new Label(outputIDInfo.getOutput(best).getLabel(), scores[best]);
    }

    /**
     * Returns the index of the first maximal score. NaN scores are ignored; if every score is NaN,
     * index 0 is returned.
     *
     * @throws IllegalArgumentException If {@code scores} is empty.
     */
    private static int argmax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("Received an empty score vector, cannot select a label");
        }
        int bestIdx = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        boolean found = false;
        for (int i = 0; i < scores.length; i++) {
            float score = scores[i];
            // Strict '>' keeps the first maximum when scores tie.
            if (!Float.isNaN(score) && (!found || score > bestScore)) {
                bestIdx = i;
                bestScore = score;
                found = true;
            }
        }
        return bestIdx;
    }

    /**
     * Extracts a {@code [batchSize][numOutputs]} score matrix from either supported output layout.
     *
     * @throws IllegalArgumentException If the outputs do not match either layout.
     * @throws IllegalStateException If ONNX Runtime fails to read a value.
     */
    private float[][] getBatchPredictions(List<OnnxValue> outputs, ImmutableOutputInfo<Label> outputIDInfo) {
        try {
            switch (outputs.size()) {
                case 1:
                    return readProbabilityTensor(outputs.get(0), outputIDInfo.size());
                case 2:
                    // Only the second value (the per-example score maps) is read.
                    return readProbabilitySequence(outputs.get(1), outputIDInfo.size());
                default:
                    throw new IllegalArgumentException("Unexpected number of OnnxValues returned, expected 1 or 2, received " + outputs.size());
            }
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to read a value out of the onnx result.", e);
        }
    }

    /**
     * Reads a float tensor of shape {@code [batchSize, numOutputs]}.
     */
    private static float[][] readProbabilityTensor(OnnxValue value, int numOutputs) throws OrtException {
        if (!(value instanceof OnnxTensor) || ((OnnxTensor) value).getInfo().type != OnnxJavaType.FLOAT) {
            throw new IllegalArgumentException("Expected the first element to be a float OnnxTensor, found " + value);
        }
        OnnxTensor tensor = (OnnxTensor) value;
        long[] shape = tensor.getInfo().getShape();
        if (shape.length != 2 || shape[1] != numOutputs) {
            throw new IllegalArgumentException("Invalid shape for the probabilities tensor, expected shape [batchSize,numOutputs], found " + Arrays.toString(shape));
        }
        return (float[][]) tensor.getValue();
    }

    /**
     * Reads a sequence of {@code Map<Long,Float>} (one map per example, keyed by output id) into a score
     * matrix. Each map must have exactly {@code numOutputs} entries, with keys in {@code [0, numOutputs)}.
     */
    private static float[][] readProbabilitySequence(OnnxValue value, int numOutputs) throws OrtException {
        if (!(value instanceof OnnxSequence)) {
            throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + value.getInfo().toString());
        }
        OnnxSequence sequence = (OnnxSequence) value;
        SequenceInfo info = sequence.getInfo();
        if (!info.sequenceOfMaps || info.mapInfo.keyType != OnnxJavaType.INT64 || info.mapInfo.valueType != OnnxJavaType.FLOAT) {
            throw new IllegalArgumentException("Expected a List<Map<Long,Float>>, received a " + info.toString());
        }
        List<?> maps = sequence.getValue();
        float[][] predictions = new float[maps.size()][numOutputs];
        int i = 0;
        for (Object element : maps) {
            // Element, key and value types are guaranteed by the SequenceInfo check above.
            Map<?, ?> map = (Map<?, ?>) element;
            if (map.size() != numOutputs) {
                throw new IllegalArgumentException("Expected " + numOutputs + " entries in the " + i + "th element, found " + map.size());
            }
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                long key = (Long) entry.getKey();
                // Distinct keys in [0, numOutputs) plus the size check mean every slot is written exactly once.
                if (key < 0 || key >= numOutputs) {
                    throw new IllegalArgumentException("Invalid output id " + key + " in the " + i + "th element, expected a value in the range [0," + numOutputs + ")");
                }
                predictions[i][(int) key] = (Float) entry.getValue();
            }
            i++;
        }
        return predictions;
    }
}
