package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.multilabel.MultiLabel;

/* loaded from: input_file:org/tribuo/interop/onnx/MultiLabelTransformer.class */
/**
 * Converts the output of an ONNX model into Tribuo {@link MultiLabel} outputs and predictions.
 * <p>
 * The model output must be a single {@link OnnxTensor} of {@link OnnxJavaType#FLOAT} values with
 * shape {@code [batchSize, numLabels]}. Column {@code i} is mapped to the label whose id is
 * {@code i} in the supplied {@link ImmutableOutputInfo}. A label is predicted as present when its
 * score is strictly greater than the configured threshold.
 */
public class MultiLabelTransformer implements OutputTransformer<MultiLabel> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(MultiLabelTransformer.class.getName());

    /** The threshold used by the no-argument constructor. */
    public static final double DEFAULT_THRESHOLD = 0.5d;

    @Config(description = "The threshold for determining if a label is present.")
    private double threshold = DEFAULT_THRESHOLD;

    @Config(description = "Does this transformer produce probabilistic outputs.")
    private boolean generatesProbabilities = true;

    /**
     * Constructs a probabilistic transformer using {@link #DEFAULT_THRESHOLD}.
     * Also used by the configuration system, which then runs {@link #postConfig()}.
     */
    public MultiLabelTransformer() {}

    /**
     * Constructs a transformer with the supplied threshold.
     *
     * @param threshold              Scores strictly greater than this value mark a label as present.
     * @param generatesProbabilities Whether the model scores are probabilities. If true, the
     *                               threshold must lie in [0, 1].
     * @throws IllegalArgumentException If {@code generatesProbabilities} is true and the threshold
     *                                  is outside [0, 1] or is NaN.
     */
    public MultiLabelTransformer(double threshold, boolean generatesProbabilities) {
        this.threshold = threshold;
        this.generatesProbabilities = generatesProbabilities;
        if (generatesProbabilities && !isValidProbability(threshold)) {
            throw new IllegalArgumentException("Threshold must be between 0 and 1 to generate probabilities, found " + threshold);
        }
    }

    /**
     * Validates the configured fields after the configuration system has set them.
     *
     * @throws PropertyException If probabilities are generated and the threshold is outside
     *                           [0, 1] or is NaN.
     */
    public void postConfig() {
        if (generatesProbabilities && !isValidProbability(threshold)) {
            throw new PropertyException("", "threshold", "Threshold must be between 0 and 1 to generate probabilities, found " + threshold);
        }
    }

    /**
     * Checks that a value is in [0, 1]. Written as an inclusive range test so that NaN fails.
     */
    private static boolean isValidProbability(double value) {
        return value >= 0.0d && value <= 1.0d;
    }

    /**
     * Converts a single-example model output into a {@link Prediction}.
     *
     * @param tensor           The model output; must hold exactly one float tensor with one row.
     * @param outputIDInfo     Maps tensor column indices to labels.
     * @param numValidFeatures The number of features used to make this prediction.
     * @param example          The example that was predicted.
     * @return The prediction.
     * @throws IllegalArgumentException If the tensor does not hold exactly one row of the expected width.
     */
    @Override // org.tribuo.interop.onnx.OutputTransformer
    public Prediction<MultiLabel> transformToPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example) {
        float[][] predictions = getBatchPredictions(tensor);
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
        }
        return getPrediction(predictions[0], outputIDInfo, numValidFeatures, example);
    }

    /**
     * Converts a single-example model output into a {@link MultiLabel} holding the labels whose
     * score exceeds the threshold.
     *
     * @param tensor       The model output; must hold exactly one float tensor with one row.
     * @param outputIDInfo Maps tensor column indices to labels.
     * @return The predicted labels.
     * @throws IllegalArgumentException If the tensor does not hold exactly one row of the expected width.
     */
    @Override // org.tribuo.interop.onnx.OutputTransformer
    public MultiLabel transformToOutput(List<OnnxValue> tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor);
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
        }
        return getOutput(predictions[0], outputIDInfo);
    }

    /**
     * Alias of {@link #transformToOutput(List, ImmutableOutputInfo)} under its decompiler-assigned
     * name. Kept so existing callers of this name continue to work.
     *
     * @param tensor       The model output.
     * @param outputIDInfo Maps tensor column indices to labels.
     * @return The predicted labels.
     * @deprecated Use {@link #transformToOutput(List, ImmutableOutputInfo)}.
     */
    @Deprecated
    public MultiLabel transformToOutput2(List<OnnxValue> tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        return transformToOutput(tensor, outputIDInfo);
    }

    /**
     * Converts a batched model output into one {@link Prediction} per example.
     *
     * @param tensor           The model output; must hold exactly one float tensor with one row per example.
     * @param outputIDInfo     Maps tensor column indices to labels.
     * @param numValidFeatures The number of features used for each example, aligned with {@code examples}.
     * @param examples         The examples that were predicted.
     * @return The predictions, in the same order as {@code examples}.
     * @throws IllegalArgumentException If the row count or {@code numValidFeatures} length does not
     *                                  match the number of examples, or a row has the wrong width.
     */
    @Override // org.tribuo.interop.onnx.OutputTransformer
    public List<Prediction<MultiLabel>> transformToBatchPrediction(List<OnnxValue> tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo, int[] numValidFeatures, List<Example<MultiLabel>> examples) {
        float[][] predictions = getBatchPredictions(tensor);
        if (predictions.length != examples.size()) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of results, predictions.length = " + predictions.length + ", examples.size() = " + examples.size());
        }
        // Check up front so a mismatch gives a clear error rather than an index exception or silently ignored entries.
        if (numValidFeatures.length != examples.size()) {
            throw new IllegalArgumentException("Supplied numValidFeatures has the wrong length, numValidFeatures.length = " + numValidFeatures.length + ", examples.size() = " + examples.size());
        }
        List<Prediction<MultiLabel>> output = new ArrayList<>(predictions.length);
        for (int i = 0; i < predictions.length; i++) {
            output.add(getPrediction(predictions[i], outputIDInfo, numValidFeatures[i], examples.get(i)));
        }
        return output;
    }

    /**
     * Converts a batched model output into one {@link MultiLabel} per row.
     *
     * @param tensor       The model output; must hold exactly one float tensor.
     * @param outputIDInfo Maps tensor column indices to labels.
     * @return The predicted label sets, one per tensor row.
     * @throws IllegalArgumentException If a row has the wrong width.
     */
    @Override // org.tribuo.interop.onnx.OutputTransformer
    public List<MultiLabel> transformToBatchOutput(List<OnnxValue> tensor, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor);
        List<MultiLabel> output = new ArrayList<>(predictions.length);
        for (float[] row : predictions) {
            output.add(getOutput(row, outputIDInfo));
        }
        return output;
    }

    /**
     * Throws if a score row does not have exactly one entry per known label.
     */
    private static void checkRowWidth(float[] row, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        if (row.length != outputIDInfo.size()) {
            throw new IllegalArgumentException("Supplied tensor has an incorrect number of dimensions, predictions[0].length = " + row.length + ", expected " + outputIDInfo.size());
        }
    }

    /**
     * Builds a {@link MultiLabel} from the labels in {@code row} whose score is strictly greater
     * than the threshold. Each chosen label carries its score.
     */
    private MultiLabel getOutput(float[] row, ImmutableOutputInfo<MultiLabel> outputIDInfo) {
        checkRowWidth(row, outputIDInfo);
        Set<Label> present = new HashSet<>();
        for (int i = 0; i < row.length; i++) {
            double score = row[i];
            if (score > threshold) {
                present.add(new Label(outputIDInfo.getOutput(i).getLabelString(), score));
            }
        }
        return new MultiLabel(present);
    }

    /**
     * Builds a {@link Prediction} from one score row.
     * <p>
     * The predicted output holds the labels scoring strictly above the threshold. The full score
     * map holds every label, keyed by label string, each wrapped in a single-label
     * {@link MultiLabel}. The prediction's probability flag follows {@link #generatesProbabilities()}.
     */
    private Prediction<MultiLabel> getPrediction(float[] row, ImmutableOutputInfo<MultiLabel> outputIDInfo, int numValidFeatures, Example<MultiLabel> example) {
        checkRowWidth(row, outputIDInfo);
        Map<String, MultiLabel> fullScores = new HashMap<>(outputIDInfo.size());
        Set<Label> present = new HashSet<>();
        for (int i = 0; i < row.length; i++) {
            double score = row[i];
            String labelString = outputIDInfo.getOutput(i).getLabelString();
            Label label = new Label(labelString, score);
            if (score > threshold) {
                present.add(label);
            }
            fullScores.put(labelString, new MultiLabel(label));
        }
        // Previously hard-coded to true, which contradicted a transformer configured as non-probabilistic.
        return new Prediction<>(new MultiLabel(present), fullScores, numValidFeatures, example, generatesProbabilities);
    }

    /**
     * Extracts the {@code [batchSize, numLabels]} float matrix from the model output.
     *
     * @throws IllegalArgumentException If the output does not contain exactly one rank-2 float {@link OnnxTensor}.
     * @throws IllegalStateException    If ONNX Runtime fails to read the tensor value.
     */
    private float[][] getBatchPredictions(List<OnnxValue> tensor) {
        if (tensor.size() != 1) {
            throw new IllegalArgumentException("Supplied output has incorrect number of elements, expected 1, found " + tensor.size());
        }
        OnnxValue value = tensor.get(0);
        if (!(value instanceof OnnxTensor)) {
            String found = value == null ? "null" : value.getClass().toString();
            throw new IllegalArgumentException("Supplied output was not an OnnxTensor, found " + found);
        }
        OnnxTensor onnxTensor = (OnnxTensor) value;
        long[] shape = onnxTensor.getInfo().getShape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Expected shape [batchSize,numLabels], found " + Arrays.toString(shape));
        }
        if (onnxTensor.getInfo().type != OnnxJavaType.FLOAT) {
            throw new IllegalArgumentException("Supplied output was an invalid tensor type, expected float, found " + onnxTensor.getInfo().type);
        }
        try {
            return (float[][]) onnxTensor.getValue();
        } catch (OrtException e) {
            throw new IllegalStateException("Failed to read tensor value", e);
        }
    }

    @Override // org.tribuo.interop.onnx.OutputTransformer
    public boolean generatesProbabilities() {
        return generatesProbabilities;
    }

    @Override
    public String toString() {
        return "MultiLabelTransformer(threshold=" + threshold + ",generatesProbabilities=" + generatesProbabilities + ")";
    }

    /**
     * Returns the configured-object provenance of this transformer.
     *
     * @return A {@link ConfiguredObjectProvenanceImpl} with type name {@code "OutputTransformer"}.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }

    /**
     * Alias of {@link #getProvenance()} under its decompiler-assigned name. Kept so existing
     * callers of this name continue to work.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m7getProvenance() {
        return getProvenance();
    }
}
