package com.yahoo.schema;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * An ONNX model resource together with the name mappings and optional evaluation settings attached to it.
 *
 * <p>ONNX input/output names are mapped to Vespa-side names, exposed through {@link #getInputMap()} and
 * {@link #getOutputMap()}. Mappings added explicitly through {@code add*NameMapping(onnxName, vespaName)} always
 * win over the default mappings derived in {@link #setModelInfo(OnnxModelInfo)}. The defaults are inserted
 * without overwriting existing entries.
 */
public class OnnxModel extends DistributableResource implements Cloneable {

    /** Metadata about the model; {@code null} until {@link #setModelInfo(OnnxModelInfo)} is called. */
    private OnnxModelInfo modelInfo = null;
    /** ONNX input name -> normalized Vespa input source (see {@code validateInputSource}). */
    private final Map<String, String> inputMap = new HashMap<>();
    /** ONNX output name -> Vespa identifier. */
    private final Map<String, String> outputMap = new HashMap<>();
    /** Initializer names accumulated from every model info passed to {@link #setModelInfo(OnnxModelInfo)}. */
    private final Set<String> initializers = new HashSet<>();
    // Optional stateless evaluation settings: null means "not set" and is exposed as Optional.empty().
    private String statelessExecutionMode = null;
    private Integer statelessInterOpThreads = null;
    private Integer statelessIntraOpThreads = null;
    private GpuDevice gpuDevice = null;

    /**
     * A GPU device number paired with a {@code required} flag.
     *
     * @param deviceNumber the device number; must not be negative
     * @param required     the flag as given; how it is interpreted is up to consumers of
     *                     {@link OnnxModel#getGpuDevice()}
     */
    public record GpuDevice(int deviceNumber, boolean required) {

        /** @throws IllegalArgumentException if {@code deviceNumber} is negative */
        public GpuDevice {
            if (deviceNumber < 0) {
                throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
            }
        }
    }

    /**
     * Creates a model with the given name.
     * Unlike {@link #OnnxModel(String, String)}, this constructor does not call {@code validate()}.
     *
     * @param name the model name
     */
    public OnnxModel(String name) {
        super(name);
    }

    /**
     * Creates a model and immediately runs {@code validate()} on it.
     *
     * @param name the model name
     * @param path second argument forwarded to the {@code DistributableResource} constructor
     *             (presumably the model file path; confirm against DistributableResource)
     */
    public OnnxModel(String name, String path) {
        super(name, path);
        validate();
    }

    /**
     * Returns a copy of this model.
     *
     * <p>This is the decompiler-renamed {@code clone()} override. The name is kept because the superclass
     * method it overrides and calls carries the same name.
     *
     * <p>NOTE(review): assuming {@code super.mo14clone()} ends in {@link Object#clone()}, the copy is shallow.
     * It shares {@code inputMap}, {@code outputMap} and {@code initializers} with this instance, because those
     * fields are final and are not re-created. Mapping changes on either object are therefore visible on both.
     * Confirm this is intended.
     *
     * @return the copy
     * @throws RuntimeException wrapping {@link CloneNotSupportedException} if the superclass refuses to clone
     */
    @Override
    public OnnxModel mo14clone() {
        try {
            return (OnnxModel) super.mo14clone();
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException("Clone not supported", e);
        }
    }

    /**
     * URI-based ONNX models are not supported, so this always fails.
     *
     * @param uri ignored
     * @throws IllegalArgumentException always
     */
    @Override
    public void setUri(String uri) {
        throw new IllegalArgumentException("URI for ONNX models are not currently supported");
    }

    /**
     * Maps an ONNX input name to a Vespa input source, replacing any existing mapping for that name.
     *
     * @param onnxName  the input name as known to the ONNX model
     * @param vespaName the Vespa-side source; normalized/validated as in
     *                  {@link #addInputNameMapping(String, String, boolean)}
     */
    public void addInputNameMapping(String onnxName, String vespaName) {
        addInputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Normalizes and validates a Vespa-side input source.
     *
     * <p>Accepted forms:
     * <ul>
     *   <li>text that {@code Reference.simple} cannot parse, which is re-read through
     *       {@code Reference.fromIdentifier};</li>
     *   <li>a simple reference accepted by {@code FeatureNames.isSimpleFeature};</li>
     *   <li>a simple ranking-expression wrapper that has a simple argument.</li>
     * </ul>
     *
     * @param source the source text
     * @return the string form of the resulting {@code Reference}
     * @throws IllegalArgumentException for a simple reference of any other kind
     */
    private String validateInputSource(String source) {
        Optional<Reference> simple = Reference.simple(source);
        if (simple.isEmpty()) {
            // Not parseable as a simple feature reference: fall back to treating the text as an identifier.
            return Reference.fromIdentifier(source).toString();
        }
        Reference reference = simple.get();
        if (FeatureNames.isSimpleFeature(reference)) {
            return reference.toString();
        }
        if (reference.isSimpleRankingExpressionWrapper() && reference.simpleArgument().isPresent()) {
            return reference.toString();
        }
        throw new IllegalArgumentException("invalid input for ONNX model " + getName() + ": " + source);
    }

    /**
     * Maps an ONNX input name to a Vespa input source.
     *
     * <p>{@code vespaName} is validated even when the mapping ends up not being stored, so an invalid source
     * always fails.
     *
     * @param onnxName  the input name as known to the ONNX model
     * @param vespaName the Vespa-side source
     * @param overwrite if {@code false}, an existing mapping for {@code onnxName} is left unchanged
     * @throws NullPointerException     if either name is {@code null}
     * @throws IllegalArgumentException if {@code vespaName} is rejected by input-source validation
     */
    public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        String source = validateInputSource(vespaName);
        if (overwrite || !inputMap.containsKey(onnxName)) {
            inputMap.put(onnxName, source);
        }
    }

    /**
     * Maps an ONNX output name to a Vespa identifier, replacing any existing mapping for that name.
     *
     * @param onnxName  the output name as known to the ONNX model
     * @param vespaName the Vespa-side identifier
     */
    public void addOutputNameMapping(String onnxName, String vespaName) {
        addOutputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Maps an ONNX output name to a Vespa identifier.
     *
     * <p>The identifier is passed through {@code Reference.fromIdentifier} before the overwrite check, so
     * that call is always made.
     *
     * @param onnxName  the output name as known to the ONNX model
     * @param vespaName the Vespa-side identifier
     * @param overwrite if {@code false}, an existing mapping for {@code onnxName} is left unchanged
     * @throws NullPointerException if either name is {@code null}
     */
    public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        Reference identifier = Reference.fromIdentifier(vespaName);
        if (overwrite || !outputMap.containsKey(onnxName)) {
            outputMap.put(onnxName, identifier.toString());
        }
    }

    /**
     * Attaches model metadata.
     *
     * <p>For every model input and output, a default mapping to
     * {@code OnnxModelInfo.asValidIdentifier(name)} is added without overwriting existing mappings. The
     * model's initializers are added to the accumulated set, and the info is stored.
     *
     * @param onnxModelInfo the model metadata
     * @throws NullPointerException if {@code onnxModelInfo} is {@code null}
     */
    public void setModelInfo(OnnxModelInfo onnxModelInfo) {
        Objects.requireNonNull(onnxModelInfo, "Onnx model info cannot be null");
        for (String input : onnxModelInfo.getInputs()) {
            addInputNameMapping(input, OnnxModelInfo.asValidIdentifier(input), false);
        }
        for (String output : onnxModelInfo.getOutputs()) {
            addOutputNameMapping(output, OnnxModelInfo.asValidIdentifier(output), false);
        }
        initializers.addAll(onnxModelInfo.getInitializers());
        modelInfo = onnxModelInfo;
    }

    /** @return an unmodifiable live view of the ONNX-input to Vespa-source mappings */
    public Map<String, String> getInputMap() {
        return Collections.unmodifiableMap(inputMap);
    }

    /** @return an unmodifiable live view of the ONNX-output to Vespa-identifier mappings */
    public Map<String, String> getOutputMap() {
        return Collections.unmodifiableMap(outputMap);
    }

    /** @return an immutable snapshot of the initializer names collected so far */
    public Set<String> getInitializers() {
        return Set.copyOf(initializers);
    }

    /**
     * Returns the default output reported by the model info.
     *
     * <p>NOTE(review): the fallback used when no model info is set, {@code VespaModel.ROOT_CONFIGID}, looks
     * like a decompiler substitution of a constant for an equal string literal (likely {@code ""}).
     * Confirm this, and consider returning the literal to drop the dependency on VespaModel.
     *
     * @return the default output name, or {@code VespaModel.ROOT_CONFIGID} if no model info has been set
     */
    public String getDefaultOutput() {
        return modelInfo != null ? modelInfo.getDefaultOutput() : VespaModel.ROOT_CONFIGID;
    }

    /**
     * Resolves the tensor type of the given name through the model info.
     *
     * <p>The decompiler reports this method as originally package-private. It is kept public so that
     * existing callers keep compiling.
     *
     * @param name       the name to resolve, passed to {@code OnnxModelInfo.getTensorType}
     * @param inputTypes passed through to {@code OnnxModelInfo.getTensorType}
     * @return the resolved type, or {@code TensorType.empty} if no model info has been set
     */
    public TensorType getTensorType(String name, Map<String, TensorType> inputTypes) {
        return modelInfo != null ? modelInfo.getTensorType(name, inputTypes) : TensorType.empty;
    }

    /**
     * Sets the stateless execution mode.
     *
     * <p>Only {@code "parallel"} and {@code "sequential"} are accepted (case-insensitively) and stored in
     * lower case. Any other value, including {@code null}, is silently ignored and leaves the current setting
     * unchanged.
     *
     * @param mode the requested mode
     */
    public void setStatelessExecutionMode(String mode) {
        if ("parallel".equalsIgnoreCase(mode)) {
            statelessExecutionMode = "parallel";
        } else if ("sequential".equalsIgnoreCase(mode)) {
            statelessExecutionMode = "sequential";
        }
    }

    /** @return the stateless execution mode, if one has been set */
    public Optional<String> getStatelessExecutionMode() {
        return Optional.ofNullable(statelessExecutionMode);
    }

    /**
     * Sets the stateless inter-op thread count.
     *
     * @param threads the thread count; negative values are ignored
     */
    public void setStatelessInterOpThreads(int threads) {
        if (threads >= 0) {
            statelessInterOpThreads = threads;
        }
    }

    /** @return the stateless inter-op thread count, if one has been set */
    public Optional<Integer> getStatelessInterOpThreads() {
        return Optional.ofNullable(statelessInterOpThreads);
    }

    /**
     * Sets the stateless intra-op thread count.
     *
     * @param threads the thread count; negative values are ignored
     */
    public void setStatelessIntraOpThreads(int threads) {
        if (threads >= 0) {
            statelessIntraOpThreads = threads;
        }
    }

    /** @return the stateless intra-op thread count, if one has been set */
    public Optional<Integer> getStatelessIntraOpThreads() {
        return Optional.ofNullable(statelessIntraOpThreads);
    }

    /**
     * Sets the GPU device.
     *
     * <p>A negative device number is ignored, leaving the current setting unchanged. The
     * {@link GpuDevice} constructor's negative-number check is therefore never triggered from here.
     *
     * @param deviceNumber the device number
     * @param required     the {@code required} flag to store
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        if (deviceNumber >= 0) {
            gpuDevice = new GpuDevice(deviceNumber, required);
        }
    }

    /** @return the GPU device setting, if one has been set */
    public Optional<GpuDevice> getGpuDevice() {
        return Optional.ofNullable(gpuDevice);
    }
}
