package com.yahoo.schema.processing;

import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.search.DocumentDatabase;
import com.yahoo.vespa.model.search.SearchCluster;
import java.util.Iterator;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.class */
/**
 * Loads the ONNX model test application and verifies the rank profile, ranking constant
 * and ONNX model configs produced for the first document database of the first search cluster.
 */
public class RankingExpressionWithOnnxModelTestCase {

    /** Property name asserted for the {@code my_function} ranking script in several profiles. */
    private static final String MY_FUNCTION_SCRIPT = "rankingExpression(my_function).rankingScript";

    /** Property name asserted for the first-phase ranking script in several profiles. */
    private static final String FIRST_PHASE_SCRIPT = "rankingExpression(firstphase).rankingScript";

    /** Test application package directory; services.xml and schemas/ are read from here. */
    private final Path applicationDir = Path.fromString("src/test/integration/onnx-model/");

    /** Deletes the generated-models directory under the application directory after each test. */
    @AfterEach
    public void removeGeneratedModelFiles() {
        IOUtils.recursiveDeleteDir(this.applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
    }

    /**
     * Builds a model from the application directory and checks its config, then copies
     * services.xml, the schemas and the generated-models directory into a "copy" subdirectory
     * and checks that a model built from that copy yields the same config.
     *
     * @throws Exception if building a model or copying the application files fails
     */
    @Test
    void testOnnxModelFeature() throws Exception {
        VespaModel original = loadModel(this.applicationDir);
        assertTransformedFeature(original);
        assertGeneratedConfig(original);

        Path copyDir = this.applicationDir.append("copy");
        try {
            // Result intentionally not checked, matching prior behavior; a failed creation
            // surfaces in the copy calls below.
            copyDir.toFile().mkdirs();
            IOUtils.copy(this.applicationDir.append("services.xml").toString(),
                         copyDir.append("services.xml").toString());
            IOUtils.copyDirectory(this.applicationDir.append("schemas").toFile(),
                                  copyDir.append("schemas").toFile());
            IOUtils.copyDirectory(this.applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
                                  copyDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());

            VespaModel fromCopy = loadModel(copyDir);
            assertTransformedFeature(fromCopy);
            assertGeneratedConfig(fromCopy);
        } finally {
            // Always remove the copy, whether the assertions passed or not.
            IOUtils.recursiveDeleteDir(copyDir.toFile());
        }
    }

    /**
     * Builds a {@link VespaModel} from the application package found at the given path.
     *
     * @param path directory containing the application package
     * @return the model built from a deploy state wrapping that package
     * @throws Exception if the model cannot be built
     */
    private VespaModel loadModel(Path path) throws Exception {
        ApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile());
        DeployState deployState = new DeployState.Builder().applicationPackage(applicationPackage).build();
        return new VespaModel(deployState);
    }

    /**
     * Returns the first document database of the first search cluster in the model.
     * The casts are retained because the element types of the two lists are not visible here.
     */
    private static DocumentDatabase firstDocumentDatabase(VespaModel vespaModel) {
        SearchCluster cluster = (SearchCluster) vespaModel.getSearchClusters().get(0);
        return (DocumentDatabase) cluster.getDocumentDbs().get(0);
    }

    /**
     * Verifies the ranking constants config and the ONNX models config of the model.
     *
     * @param vespaModel the model whose generated config is checked
     */
    private void assertGeneratedConfig(VespaModel vespaModel) {
        DocumentDatabase db = firstDocumentDatabase(vespaModel);
        assertRankingConstants(db);
        assertOnnxModels(db);
    }

    /** Expects exactly one ranking constant, {@code my_constant}, with the asserted type and file reference. */
    private static void assertRankingConstants(DocumentDatabase db) {
        RankingConstantsConfig.Builder builder = new RankingConstantsConfig.Builder();
        db.getConfig(builder);
        RankingConstantsConfig constants = builder.build();

        Assertions.assertEquals(1, constants.constant().size());
        Assertions.assertEquals("my_constant", constants.constant(0).name());
        Assertions.assertEquals("tensor(d0[2])", constants.constant(0).type());
        Assertions.assertEquals("files/constant.json", constants.constant(0).fileref().value());
    }

    /** Expects six ONNX models, all with dry-run-on-setup enabled, and checks their inputs and outputs. */
    private static void assertOnnxModels(DocumentDatabase db) {
        OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
        db.getConfig(builder);
        OnnxModelsConfig config = new OnnxModelsConfig(builder);

        Assertions.assertEquals(6, config.model().size());
        for (OnnxModelsConfig.Model each : config.model()) {
            Assertions.assertTrue(each.dry_run_on_setup());
        }

        OnnxModelsConfig.Model myModel = config.model(0);
        Assertions.assertEquals("my_model", myModel.name());
        Assertions.assertEquals(3, myModel.input().size());
        Assertions.assertEquals("second/input:0", myModel.input(0).name());
        Assertions.assertEquals("constant(my_constant)", myModel.input(0).source());
        Assertions.assertEquals("first_input", myModel.input(1).name());
        Assertions.assertEquals("attribute(document_field)", myModel.input(1).source());
        Assertions.assertEquals("third_input", myModel.input(2).name());
        Assertions.assertEquals("rankingExpression(my_function)", myModel.input(2).source());
        Assertions.assertEquals(3, myModel.output().size());
        Assertions.assertEquals("path/to/output:0", myModel.output(0).name());
        Assertions.assertEquals("out", myModel.output(0).as());
        Assertions.assertEquals("path/to/output:1", myModel.output(1).name());
        Assertions.assertEquals("path_to_output_1", myModel.output(1).as());
        Assertions.assertEquals("path/to/output:2", myModel.output(2).name());
        Assertions.assertEquals("path_to_output_2", myModel.output(2).as());

        OnnxModelsConfig.Model dynamicModel = config.model(1);
        Assertions.assertEquals("dynamic_model", dynamicModel.name());
        Assertions.assertEquals(1, dynamicModel.input().size());
        Assertions.assertEquals(1, dynamicModel.output().size());
        Assertions.assertEquals("rankingExpression(my_function)", dynamicModel.input(0).source());

        OnnxModelsConfig.Model unboundModel = config.model(2);
        Assertions.assertEquals("unbound_model", unboundModel.name());
        Assertions.assertEquals(1, unboundModel.input().size());
        Assertions.assertEquals(1, unboundModel.output().size());
        Assertions.assertEquals("rankingExpression(my_function)", unboundModel.input(0).source());

        OnnxModelsConfig.Model filesModel = config.model(3);
        Assertions.assertEquals("files_model_onnx", filesModel.name());
        Assertions.assertEquals(3, filesModel.input().size());
        Assertions.assertEquals(3, filesModel.output().size());
        Assertions.assertEquals("path/to/output:0", filesModel.output(0).name());
        Assertions.assertEquals("path_to_output_0", filesModel.output(0).as());
        Assertions.assertEquals("path/to/output:1", filesModel.output(1).name());
        Assertions.assertEquals("path_to_output_1", filesModel.output(1).as());
        Assertions.assertEquals("path/to/output:2", filesModel.output(2).name());
        Assertions.assertEquals("path_to_output_2", filesModel.output(2).as());

        OnnxModelsConfig.Model anotherModel = config.model(4);
        Assertions.assertEquals("another_model", anotherModel.name());
        Assertions.assertEquals("third_input", anotherModel.input(2).name());
        Assertions.assertEquals("rankingExpression(another_function)", anotherModel.input(2).source());

        OnnxModelsConfig.Model summaryModel = config.model(5);
        Assertions.assertEquals("files_summary_model_onnx", summaryModel.name());
        Assertions.assertEquals(3, summaryModel.input().size());
        Assertions.assertEquals(3, summaryModel.output().size());
    }

    /**
     * Verifies the rank profiles config: ten profiles in total, with selected properties of
     * profiles 2 through 8 checked by index.
     *
     * @param vespaModel the model whose rank profiles are checked
     */
    private void assertTransformedFeature(VespaModel vespaModel) {
        DocumentDatabase db = firstDocumentDatabase(vespaModel);
        RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
        db.getConfig(builder);
        RankProfilesConfig config = new RankProfilesConfig(builder);

        Assertions.assertEquals(10, config.rankprofile().size());

        assertProfileName(config, 2, "test_model_config");
        assertPropertyName(config, 2, 0, MY_FUNCTION_SCRIPT);
        assertProperty(config, 2, 2, "vespa.rank.firstphase", "rankingExpression(firstphase)");
        assertProperty(config, 2, 3, FIRST_PHASE_SCRIPT, "onnx(my_model).out{d0:1}");

        assertProfileName(config, 3, "test_generated_model_config");
        assertPropertyName(config, 3, 0, MY_FUNCTION_SCRIPT);
        assertPropertyName(config, 3, 2, "rankingExpression(first_input).rankingScript");
        assertPropertyName(config, 3, 4, "rankingExpression(second_input).rankingScript");
        assertPropertyName(config, 3, 6, "rankingExpression(third_input).rankingScript");
        assertProperty(config, 3, 8, "vespa.rank.firstphase", "rankingExpression(firstphase)");
        assertProperty(config, 3, 9, FIRST_PHASE_SCRIPT, "onnx(files_model_onnx).path_to_output_1{d0:1}");

        assertProfileName(config, 4, "test_summary_features");
        assertProperty(config, 4, 0, "vespa.type.feature.onnx(another_model)", "tensor<float>(d0[2])");
        assertPropertyName(config, 4, 2, "rankingExpression(another_function).rankingScript");
        assertProperty(config, 4, 5, FIRST_PHASE_SCRIPT, "1");
        assertProperty(config, 4, 6, "vespa.summary.feature", "onnx(another_model).out");
        assertProperty(config, 4, 7, "vespa.summary.feature", "onnx(files_summary_model_onnx).path_to_output_2");

        assertProfileName(config, 5, "test_dynamic_model");
        assertPropertyName(config, 5, 0, MY_FUNCTION_SCRIPT);
        assertProperty(config, 5, 3, FIRST_PHASE_SCRIPT, "onnx(dynamic_model).my_output{d0:0, d1:1}");

        assertProfileName(config, 6, "test_dynamic_model_2");
        assertProperty(config, 6, 5, FIRST_PHASE_SCRIPT, "onnx(dynamic_model).my_output{d0:0, d1:2}");

        assertProfileName(config, 7, "test_dynamic_model_with_transformer_tokens");
        assertProperty(config, 7, 1, MY_FUNCTION_SCRIPT,
                       "tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 1.0, 0.0))))");

        assertProfileName(config, 8, "test_unbound_model");
        assertPropertyName(config, 8, 0, MY_FUNCTION_SCRIPT);
        assertProperty(config, 8, 3, FIRST_PHASE_SCRIPT, "onnx(unbound_model).my_output{d0:0, d1:1}");
    }

    /** Asserts the name of the rank profile at the given index. */
    private static void assertProfileName(RankProfilesConfig config, int profile, String expectedName) {
        Assertions.assertEquals(expectedName, config.rankprofile(profile).name());
    }

    /** Asserts the name of one fef property of the rank profile at the given index. */
    private static void assertPropertyName(RankProfilesConfig config, int profile, int property, String expectedName) {
        Assertions.assertEquals(expectedName, config.rankprofile(profile).fef().property(property).name());
    }

    /** Asserts the name, then the value, of one fef property of the rank profile at the given index. */
    private static void assertProperty(RankProfilesConfig config, int profile, int property,
                                       String expectedName, String expectedValue) {
        assertPropertyName(config, profile, property, expectedName);
        Assertions.assertEquals(expectedValue, config.rankprofile(profile).fef().property(property).value());
    }
}
