package org.tribuo.interop.onnx;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.List;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;

/* loaded from: input_file:org/tribuo/interop/onnx/ImageTransformer.class */
/**
 * Converts {@link SparseVector}s into float {@link OnnxTensor}s shaped as images.
 * <p>
 * A single vector becomes a tensor of shape {@code [1, channels, height, width]}, and a list of
 * vectors becomes a tensor of shape {@code [list.size(), channels, height, width]}. Vector index
 * {@code i} is written to flat position {@code i} of its batch slot, so the features must already
 * be laid out in the flattened order the model expects (presumably row-major
 * channel/height/width; confirm against the model).
 * <p>
 * NOTE(review): the vector dimension is expected to equal {@code channels * height * width}. This
 * class does not check that. A mismatch is presumably rejected by ONNX Runtime when the tensor
 * is created.
 */
public class ImageTransformer implements ExampleTransformer {
    private static final long serialVersionUID = 1L;

    @Config(mandatory = true, description = "Image width.")
    private int width;

    @Config(mandatory = true, description = "Image height.")
    private int height;

    @Config(mandatory = true, description = "Number of channels.")
    private int channels;

    /**
     * No-argument constructor for use by the configuration system. The {@code @Config} fields are
     * expected to be injected afterwards and then checked by {@link #postConfig()}.
     */
    private ImageTransformer() {
    }

    /**
     * Constructs an image transformer.
     *
     * @param channels The number of channels; must be at least 1.
     * @param height   The image height; must be at least 1.
     * @param width    The image width; must be at least 1.
     * @throws PropertyException If any argument is less than 1.
     */
    public ImageTransformer(int channels, int height, int width) {
        checkDimensions(channels, height, width);
        this.width = width;
        this.height = height;
        this.channels = channels;
    }

    /**
     * Validates the configured dimensions after configuration injection.
     *
     * @throws PropertyException If any of width, height or channels is less than 1.
     */
    public void postConfig() {
        checkDimensions(this.channels, this.height, this.width);
    }

    /**
     * Shared validation for the constructor and {@link #postConfig()}.
     *
     * @param channels The number of channels.
     * @param height   The image height.
     * @param width    The image width.
     * @throws PropertyException If any value is less than 1.
     */
    private static void checkDimensions(int channels, int height, int width) {
        if (width < 1 || height < 1 || channels < 1) {
            throw new PropertyException("", "Inputs must be positive integers, found [c=" + channels + ",h=" + height + ",w=" + width + "]");
        }
    }

    /**
     * Allocates a native-order direct float buffer large enough for {@code numVectors} vectors
     * of {@code vectorSize} floats each.
     *
     * @param numVectors The number of vectors in the batch.
     * @param vectorSize The dimension of each vector.
     * @return A zero-filled direct float buffer.
     * @throws IllegalArgumentException If the required byte capacity does not fit in an int.
     */
    private static FloatBuffer allocateFloatBuffer(int numVectors, int vectorSize) {
        int capacityBytes;
        try {
            // Checked arithmetic: silent int overflow here would give a negative or wrongly sized buffer.
            capacityBytes = Math.multiplyExact(Math.multiplyExact(numVectors, vectorSize), Float.BYTES);
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("Batch of " + numVectors + " vectors of dimension " + vectorSize
                    + " exceeds the maximum buffer size", e);
        }
        return ByteBuffer.allocateDirect(capacityBytes).order(ByteOrder.nativeOrder()).asFloatBuffer();
    }

    /**
     * Writes the active elements of {@code vector} into {@code buffer} at absolute positions
     * {@code offset + index}. Positions not written keep their existing value, which is zero for a
     * freshly allocated buffer. The buffer's position is not modified.
     *
     * @param buffer The destination buffer.
     * @param offset The absolute position in {@code buffer} where this vector's slot starts.
     * @param vector The vector to copy.
     */
    private static void innerTransform(FloatBuffer buffer, int offset, SparseVector vector) {
        VectorIterator iterator = vector.iterator();
        while (iterator.hasNext()) {
            VectorTuple tuple = iterator.next();
            buffer.put(offset + tuple.index, (float) tuple.value);
        }
    }

    /**
     * Converts a single vector into a tensor of shape {@code [1, channels, height, width]}.
     *
     * @param env    The ONNX Runtime environment.
     * @param vector The vector to convert.
     * @return A float tensor containing the vector's values.
     * @throws OrtException If ONNX Runtime fails to create the tensor.
     */
    @Override // org.tribuo.interop.onnx.ExampleTransformer
    public OnnxTensor transform(OrtEnvironment env, SparseVector vector) throws OrtException {
        FloatBuffer buffer = allocateFloatBuffer(1, vector.size());
        innerTransform(buffer, 0, vector);
        buffer.rewind();
        return OnnxTensor.createTensor(env, buffer, new long[]{1, this.channels, this.height, this.width});
    }

    /**
     * Converts a batch of vectors into a tensor of shape
     * {@code [vectors.size(), channels, height, width]}. An empty list produces a tensor with a
     * zero-length batch dimension.
     *
     * @param env     The ONNX Runtime environment.
     * @param vectors The vectors to convert; all must have the same dimension.
     * @return A float tensor containing the vectors' values, one batch slot per vector.
     * @throws OrtException             If ONNX Runtime fails to create the tensor.
     * @throws IllegalArgumentException If the vectors differ in dimension, or the batch is too
     *                                  large to fit in a single buffer.
     */
    @Override // org.tribuo.interop.onnx.ExampleTransformer
    public OnnxTensor transform(OrtEnvironment env, List<SparseVector> vectors) throws OrtException {
        if (vectors.isEmpty()) {
            return OnnxTensor.createTensor(env, FloatBuffer.allocate(0), new long[]{0, this.channels, this.height, this.width});
        }
        int dimension = vectors.get(0).size();
        FloatBuffer buffer = allocateFloatBuffer(vectors.size(), dimension);
        int offset = 0;
        for (SparseVector vector : vectors) {
            // Check before writing. A mismatched vector must never be copied into the buffer,
            // and a longer one could otherwise overrun it and fail with an unrelated exception.
            if (vector.size() != dimension) {
                throw new IllegalArgumentException("Vectors are not all the same dimension, expected " + dimension + ", found " + vector.size());
            }
            innerTransform(buffer, offset, vector);
            offset += dimension;
        }
        buffer.rewind();
        return OnnxTensor.createTensor(env, buffer, new long[]{vectors.size(), this.channels, this.height, this.width});
    }

    @Override
    public String toString() {
        return "ImageTransformer(channels=" + this.channels + ",height=" + this.height + ",width=" + this.width + ")";
    }

    /**
     * Returns the configuration provenance of this transformer.
     *
     * @return A provenance describing this object's configured fields.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "ExampleTransformer");
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, retained for backward compatibility.
     *
     * @return The same value as {@link #getProvenance()}.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m2getProvenance() {
        return getProvenance();
    }
}
