package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.api.OnnxMemoryStats;
import com.yahoo.io.IOUtils;
import com.yahoo.json.Jackson;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.ProcessBuilder;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/* loaded from: input_file:com/yahoo/vespa/model/ml/OnnxModelProbe.class */
/**
 * Determines the output type of an ONNX model for a given set of input types by running the
 * external {@code vespa-analyze-onnx-model} tool.
 *
 * <p>Successful probe results are stored as tab-separated {@code key<TAB>type} lines in a
 * {@code .probed_output_types} file under {@link ApplicationPackage#MODELS_GENERATED_REPLICATED_DIR},
 * and that file is consulted before the tool is invoked again.
 */
public class OnnxModelProbe {
    private static final String binary = "vespa-analyze-onnx-model";

    /**
     * Returns the type of the named model output, given the types of the model inputs.
     *
     * <p>Probing is best effort: if the result cannot be determined (no stored result and the
     * model file is missing, the tool fails, its output cannot be parsed, ...) this returns
     * {@link TensorType#empty} rather than throwing.
     *
     * @param applicationPackage the application package holding the model and the stored results
     * @param path path of the ONNX model inside the application package
     * @param str name of the model output to find the type of
     * @param map input name to input type
     * @return the probed output type, or {@link TensorType#empty} if it could not be determined
     */
    public static TensorType probeModel(ApplicationPackage applicationPackage, Path path, String str, Map<String, TensorType> map) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(str, map);
        try {
            // Prefer a previously stored result; only run the tool when there is none.
            outputType = readProbedOutputType(applicationPackage, path, contextKey);
            if (outputType.equals(TensorType.empty) && applicationPackage.getFile(path).exists()) {
                String modelPath = applicationPackage.getFileReference(path).getAbsolutePath();
                JsonNode probeResult = callVespaAnalyzeOnnxModel(createJsonInput(modelPath, map));
                // Assigned before the writes below so that a failing write still returns the probed type.
                outputType = outputTypeFromJson(probeResult, str);
                writeMemoryStats(applicationPackage, path, OnnxMemoryStats.fromJson(probeResult));
                if (!outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(applicationPackage, path, contextKey, outputType);
                }
            }
        } catch (InterruptedException e) {
            // Still best effort, but keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
        } catch (IOException | IllegalArgumentException ignored) {
            // Deliberately ignored: probing is best effort and callers handle TensorType.empty.
        }
        return outputType;
    }

    /** Writes the given memory statistics as pretty-printed JSON to the memory stats file of the model. */
    private static void writeMemoryStats(ApplicationPackage applicationPackage, Path modelPath, OnnxMemoryStats memoryStats) throws IOException {
        String statsFile = applicationPackage.getFileReference(OnnxMemoryStats.memoryStatsFilePath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(statsFile, memoryStats.toJson().toPrettyString(), false);
    }

    /**
     * Builds the lookup key for a stored probe result: the output name followed by the inputs,
     * sorted by name, as {@code outputName:input1:type1,input2:type2}.
     */
    private static String createContextKey(String outputName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(outputName).append(":");
        inputTypes.entrySet().stream()
                  .sorted(Map.Entry.comparingByKey())
                  .forEachOrdered(entry -> key.append(entry.getKey()).append(":").append(entry.getValue()).append(","));
        // Drops the trailing ','. With no inputs this drops the ':' after the output name instead;
        // kept as is so that keys in already written files continue to match.
        return key.substring(0, key.length() - 1);
    }

    /** Returns the application package path of the file holding stored probe results for the given model. */
    private static Path probedOutputTypesPath(Path modelPath) {
        String fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /**
     * Stores the output type of the given output name and input types for the given model.
     *
     * @param applicationPackage the application package to store the result in
     * @param path path of the ONNX model inside the application package
     * @param str name of the model output
     * @param map input name to input type
     * @param tensorType the output type to store
     * @throws IOException if the result cannot be written
     */
    static void writeProbedOutputType(ApplicationPackage applicationPackage, Path path, String str, Map<String, TensorType> map, TensorType tensorType) throws IOException {
        writeProbedOutputType(applicationPackage, path, createContextKey(str, map), tensorType);
    }

    /** Writes one {@code key<TAB>type} line to the stored results file of the model. */
    private static void writeProbedOutputType(ApplicationPackage applicationPackage, Path modelPath, String contextKey, TensorType outputType) throws IOException {
        String resultsFile = applicationPackage.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        // NOTE(review): the trailing 'true' is presumably IOUtils' append flag - confirm.
        IOUtils.writeFile(resultsFile, contextKey + "\t" + outputType + "\n", true);
    }

    /**
     * Looks up a stored probe result.
     *
     * @return the type stored on the first line whose key equals the given key,
     *         or {@link TensorType#empty} if there is no results file or no such line
     */
    private static TensorType readProbedOutputType(ApplicationPackage applicationPackage, Path modelPath, String contextKey) throws IOException {
        ApplicationFile file = applicationPackage.getFile(probedOutputTypesPath(modelPath));
        if (!file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t");
                // Lines without a value field (e.g. truncated writes) are skipped instead of
                // failing with an ArrayIndexOutOfBoundsException, which probeModel does not catch.
                if (fields.length >= 2 && fields[0].equals(contextKey)) {
                    return TensorType.fromSpec(fields[1]);
                }
            }
        }
        return TensorType.empty;
    }

    /**
     * Extracts the type of the named output from the {@code "outputs"} object of the tool's JSON output.
     *
     * @return the output type, or {@link TensorType#empty} if the JSON does not contain it
     */
    private static TensorType outputTypeFromJson(JsonNode root, String outputName) {
        if (!root.isObject() || !root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if (!outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    /**
     * Creates the JSON document passed to the tool:
     * {@code {"model": <modelPath>, "inputs": {<input name>: <input type>, ...}}}.
     */
    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator generator = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            generator.writeStartObject();
            generator.writeStringField("model", modelPath);
            generator.writeObjectFieldStart("inputs");
            for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
                generator.writeStringField(input.getKey(), input.getValue().toString());
            }
            generator.writeEndObject();
            generator.writeEndObject();
        }
        // The generator wrote UTF-8, so decode as UTF-8 rather than with the platform default charset.
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Runs {@code vespa-analyze-onnx-model --probe-types}, passing the given JSON on its standard
     * input, and returns its standard output parsed as JSON. Standard error is discarded.
     *
     * @throws IllegalArgumentException if the tool exits with a non-zero return code
     * @throws IOException if the tool cannot be started, communicated with, or its output parsed
     * @throws InterruptedException if interrupted while waiting for the tool to exit
     */
    private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types");
        processBuilder.redirectError(ProcessBuilder.Redirect.DISCARD);
        Process process = processBuilder.start();
        try {
            // The input is written and closed before any output is read.
            try (OutputStream stdin = process.getOutputStream()) {
                stdin.write(jsonInput.getBytes(StandardCharsets.UTF_8));
            }
            String output;
            try (InputStream stdout = process.getInputStream()) {
                // Decode the complete output as UTF-8; decoding byte by byte corrupts multi-byte characters.
                output = new String(stdout.readAllBytes(), StandardCharsets.UTF_8);
            }
            int returnCode = process.waitFor();
            if (returnCode != 0) {
                throw new IllegalArgumentException("Error from 'vespa-analyze-onnx-model'. Return code: " + returnCode + ". Output: '" + output + "'");
            }
            return Jackson.mapper().readTree(output);
        } finally {
            // No-op if the process has exited; otherwise avoids leaving it running after a failure or interrupt.
            process.destroy();
        }
    }
}
