package com.yahoo.vespa.model.ml;

import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.models.handler.ModelsEvaluationHandler;
import com.yahoo.component.ComponentId;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:com/yahoo/vespa/model/ml/ModelEvaluationTest.class */
/**
 * Tests that ML models in an application package are made available to container clusters
 * for model evaluation (ModelsEvaluator component, evaluation handler and rank profile configs).
 */
public class ModelEvaluationTest {

    /** Expected rank properties of the "small_constants_and_functions" model's rank profile. */
    private static final String EXPECTED_SMALL_CONSTANTS_PROFILE =
            "rankingExpression(output).rankingScript: onnx(small_constants_and_functions).output\n" +
            "rankingExpression(output).type: tensor<float>(d0[3])\n";

    /**
     * Verifies that the application in {@code ml_serving_not_activated} gets neither a
     * ModelsEvaluator component nor any rank profiles in its container cluster.
     */
    @Test
    void testMl_serving_not_activated() {
        Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated");
        try {
            ApplicationContainerCluster cluster =
                    new ImportedModelTester("ml_serving", appDir).createVespaModel().getContainerClusters().get("container");
            Assertions.assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));

            RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
            cluster.getConfig(builder);
            Assertions.assertEquals(0, new RankProfilesConfig(builder).rankprofile().size());
        } finally {
            deleteGeneratedModels(appDir);
        }
    }

    /**
     * Verifies ML model serving for the {@code ml_serving} application, both when built from the
     * original application and when built from a copy holding only services.xml and the
     * generated models directory.
     */
    @Test
    void testMl_serving() throws IOException {
        Assumptions.assumeTrue(OnnxRuntime.isRuntimeAvailable());
        Path appDir = Path.fromString("src/test/cfg/application/ml_serving");
        Path storedAppDir = appDir.append("copy");
        try {
            assertHasMlModels(new ImportedModelTester("ml_serving", appDir).createVespaModel(), appDir);

            // Copy services.xml and the generated models dir (presumably written by the build above)
            // into a separate application, and check that the models are still served from it.
            storedAppDir.toFile().mkdirs();
            IOUtils.copy(appDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
            IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
                                  storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
            assertHasMlModels(new ImportedModelTester("ml_serving", storedAppDir).createVespaModel(), appDir);
        } finally {
            deleteGeneratedModels(appDir);
            IOUtils.recursiveDeleteDir(storedAppDir.toFile());
        }
    }

    /** Verifies that ONNX thread and GPU settings are applied per container cluster (c1 vs c2). */
    @Test
    void testContainerSpecificModelSettings() {
        Path appDir = Path.fromString("src/test/cfg/application/onnx_cluster_specific");
        try {
            VespaModel model = new ImportedModelTester("mul", appDir).createVespaModel();
            OnnxModelsConfig.Model c1 = getOnnxModelsConfig(model.getContainerClusters().get("c1"));
            OnnxModelsConfig.Model c2 = getOnnxModelsConfig(model.getContainerClusters().get("c2"));
            Assertions.assertEquals(2, c1.stateless_intraop_threads());
            Assertions.assertEquals(4, c2.stateless_intraop_threads());
            Assertions.assertEquals(0, c1.gpu_device());
            Assertions.assertEquals(1, c2.gpu_device());
        } finally {
            deleteGeneratedModels(appDir);
        }
    }

    /** Deletes the {@code MODELS_GENERATED_DIR} under {@code appDir} so tests leave no generated files. */
    private static void deleteGeneratedModels(Path appDir) {
        IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
    }

    /** Returns the first ONNX model entry of the OnnxModelsConfig produced by the given cluster. */
    private OnnxModelsConfig.Model getOnnxModelsConfig(ApplicationContainerCluster cluster) {
        OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
        cluster.getConfig(builder);
        return new OnnxModelsConfig(builder).model(0);
    }

    /**
     * Asserts that the "container" cluster of the given model has model evaluation set up, exposes
     * the five expected model rank profiles, and that a ModelsEvaluator built from its configs can
     * evaluate them.
     *
     * @param model  the Vespa model to check
     * @param appDir the directory ONNX file references are resolved against
     */
    private void assertHasMlModels(VespaModel model, Path appDir) {
        ApplicationContainerCluster cluster = model.getContainerClusters().get("container");
        Assertions.assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
        Assertions.assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluationHandler.class.getName())));
        Assertions.assertTrue(cluster.getHandlers().stream()
                                     .anyMatch(h -> h.getComponentId().toString().equals(ModelsEvaluationHandler.class.getName())));

        RankProfilesConfig.Builder profilesBuilder = new RankProfilesConfig.Builder();
        cluster.getConfig(profilesBuilder);
        RankProfilesConfig profilesConfig = new RankProfilesConfig(profilesBuilder);

        RankingConstantsConfig.Builder constantsBuilder = new RankingConstantsConfig.Builder();
        cluster.getConfig(constantsBuilder);
        RankingConstantsConfig constantsConfig = new RankingConstantsConfig(constantsBuilder);

        RankingExpressionsConfig.Builder expressionsBuilder = new RankingExpressionsConfig.Builder();
        cluster.getConfig(expressionsBuilder);
        RankingExpressionsConfig expressionsConfig = expressionsBuilder.build();

        OnnxModelsConfig.Builder onnxBuilder = new OnnxModelsConfig.Builder();
        cluster.getConfig(onnxBuilder);
        OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig(onnxBuilder);

        Assertions.assertEquals(5, profilesConfig.rankprofile().size());
        Set<String> profileNames = profilesConfig.rankprofile().stream()
                                                 .map(RankProfilesConfig.Rankprofile::name)
                                                 .collect(Collectors.toSet());
        Assertions.assertTrue(profileNames.contains("xgboost_2_2"));
        Assertions.assertTrue(profileNames.contains("lightgbm_regression"));
        Assertions.assertTrue(profileNames.contains("add_mul"));
        Assertions.assertTrue(profileNames.contains("small_constants_and_functions"));
        Assertions.assertTrue(profileNames.contains("sqrt"));

        StringBuilder properties = new StringBuilder();
        for (RankProfilesConfig.Rankprofile.Fef.Property property : findProfile("small_constants_and_functions", profilesConfig).property()) {
            properties.append(property.name()).append(": ").append(property.value()).append("\n");
        }
        Assertions.assertEquals(EXPECTED_SMALL_CONSTANTS_PROFILE, properties.toString());

        // Resolve each ONNX file reference to a file relative to appDir instead of using file distribution.
        Map<String, File> fileMap = new HashMap<>();
        for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
            fileMap.put(onnxModel.fileref().value(), appDir.append(onnxModel.fileref().value()).toFile());
        }
        ModelsEvaluator evaluator = new ModelsEvaluator(profilesConfig, constantsConfig, expressionsConfig,
                                                        onnxModelsConfig, MockFileAcquirer.returnFiles(fileMap));
        assertEvaluatorHasModels(evaluator);
    }

    /** Asserts that the evaluator holds the expected models, functions, and evaluators. */
    private void assertEvaluatorHasModels(ModelsEvaluator evaluator) {
        Assertions.assertEquals(5, evaluator.models().size());

        Model xgboost = evaluator.models().get("xgboost_2_2");
        Assertions.assertNotNull(xgboost);
        Assertions.assertNotNull(xgboost.evaluatorOf());
        Assertions.assertNotNull(xgboost.evaluatorOf("xgboost_2_2"));

        Model lightgbm = evaluator.models().get("lightgbm_regression");
        Assertions.assertNotNull(lightgbm);
        Assertions.assertNotNull(lightgbm.evaluatorOf());
        Assertions.assertNotNull(lightgbm.evaluatorOf("lightgbm_regression"));

        Model addMul = evaluator.models().get("add_mul");
        Assertions.assertNotNull(addMul);
        Assertions.assertEquals(2, addMul.functions().size());
        Assertions.assertNotNull(addMul.evaluatorOf("output1"));
        Assertions.assertNotNull(addMul.evaluatorOf("output2"));
        Assertions.assertNotNull(addMul.evaluatorOf("default.output1"));
        Assertions.assertNotNull(addMul.evaluatorOf("default.output2"));
        Assertions.assertNotNull(addMul.evaluatorOf("default", "output1"));
        Assertions.assertNotNull(addMul.evaluatorOf("default", "output2"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "output1"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "output2"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "default.output1"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1"));
        Assertions.assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2"));
        ExpressionFunction addMulFunction = addMul.functions().get(0);
        Assertions.assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), addMulFunction.getArgumentType("input1"));
        Assertions.assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), addMulFunction.getArgumentType("input2"));

        Model sqrt = evaluator.models().get("sqrt");
        Assertions.assertNotNull(sqrt);
        Assertions.assertEquals(1, sqrt.functions().size());
        Assertions.assertNotNull(sqrt.evaluatorOf());
        Assertions.assertNotNull(sqrt.evaluatorOf("out_layer_1_1"));
        Assertions.assertNotNull(evaluator.evaluatorOf("sqrt"));
        Assertions.assertNotNull(evaluator.evaluatorOf("sqrt", "out_layer_1_1"));
        Assertions.assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), sqrt.functions().get(0).getArgumentType("input"));
    }

    /**
     * Returns the fef section of the rank profile with the given name.
     *
     * @throws IllegalArgumentException if no rank profile has that name
     */
    private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
        for (RankProfilesConfig.Rankprofile rankProfile : config.rankprofile()) {
            if (rankProfile.name().equals(name)) {
                return rankProfile.fef();
            }
        }
        throw new IllegalArgumentException("No profile named " + name);
    }
}
