package com.yahoo.schema.processing;

import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.config.provision.TenantName;
import com.yahoo.io.IOUtils;
import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.RankProfile;
import com.yahoo.schema.parser.ParseException;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests importing ONNX models into rank profiles through {@code onnx_vespa(...)}, using the model files under
 * {@code src/test/integration/onnx/}. This covers:
 * <ul>
 *   <li>different kinds of model inputs;</li>
 *   <li>the error messages produced for invalid references;</li>
 *   <li>recompiling from a copy that only contains the previously generated model expressions.</li>
 * </ul>
 */
public class RankingExpressionWithOnnxTestCase {

    private final Path applicationDir = Path.fromString("src/test/integration/onnx/");

    /** Expected first-phase expression for mnist_softmax.onnx when no model constant is overridden. */
    private static final String VESPA_EXPRESSION =
            "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))";

    /** Schema constant declarations for the two constants referenced by {@link #VESPA_EXPRESSION}. */
    private static final String VESPA_EXPRESSION_CONSTANTS =
            "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n" +
            "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }\n";

    private static final String MYTENSOR_CONSTANT =
            "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }";

    private static final String MYTENSOR_FIELD =
            "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }";

    private static final String QUERY_PROFILE = "<query-profile id='default' type='root'/>";

    /** Rank profiles where a function shadows the model constant mnist_softmax_layer_Variable, plus a child profile. */
    private static final String FUNCTION_OVERRIDING_CONSTANT_PROFILES =
            "  rank-profile my_profile {\n" +
            "    function Placeholder() {\n" +
            "      expression: tensor<float>(d0[1],d1[784])(0.0)\n" +
            "    }\n" +
            "    function mnist_softmax_layer_Variable() {\n" +
            "      expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
            "    }\n" +
            "    first-phase {\n" +
            "      expression: onnx_vespa('mnist_softmax.onnx')" +
            "    }\n" +
            "  }" +
            "  rank-profile my_profile_child inherits my_profile {\n" +
            "  }";

    private static final String FUNCTION_OVERRIDING_CONSTANT_CONSTANTS =
            "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[1],d1[10]) }\n";

    /** Expected expression when mnist_softmax_layer_Variable resolves to the function rather than a constant. */
    private static final String FUNCTION_OVERRIDING_CONSTANT_EXPRESSION =
            "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))";

    /** A single rank profile using small_constants_and_functions.onnx. */
    private static final String SMALL_CONSTANTS_PROFILE =
            "  rank-profile my_profile {\n" +
            "    function input() {\n" +
            "      expression: tensor<float>(d0[3])(0.0)\n" +
            "    }\n" +
            "    first-phase {\n" +
            "      expression: onnx_vespa('small_constants_and_functions.onnx')" +
            "    }\n" +
            "  }";

    private static final String SMALL_CONSTANTS_PROFILE_WITH_CHILD =
            SMALL_CONSTANTS_PROFILE +
            "  rank-profile my_profile_child inherits my_profile {\n" +
            "  }";

    private static final String SMALL_CONSTANTS_EXPRESSION =
            "join(imported_ml_function_small_constants_and_functions_exp_output, reduce(join(join(reduce(imported_ml_function_small_constants_and_functions_exp_output, sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), constant(small_constants_and_functions_epsilon), f(a,b)(a + b)), sum, d0), f(a,b)(a / b))";

    private static final String SMALL_CONSTANTS_FUNCTION_NAME =
            "imported_ml_function_small_constants_and_functions_exp_output";

    private static final String SMALL_CONSTANTS_FUNCTION = "map(input, f(a)(exp(a)))";

    /**
     * A mock application package rooted at a directory on disk. It serves {@link #getFile} and {@link #getFiles}
     * from that directory. NOTE(review): presumably this lets generated model files be written and read back
     * during compilation — confirm against MockApplicationPackage.
     */
    public static class StoringApplicationPackage extends MockApplicationPackage {

        public StoringApplicationPackage(Path path) {
            this(path, null, null);
        }

        /**
         * @param path             the root directory of the package
         * @param queryProfile     query profile XML passed on to the mock package, or null
         * @param queryProfileType query profile type XML passed on to the mock package, or null
         */
        StoringApplicationPackage(Path path, String queryProfile, String queryProfileType) {
            super(new File(path.toString()), (String) null, (String) null, List.of(), Map.of(), (String) null, (String) null, (String) null, false, queryProfile, queryProfileType, TenantName.defaultName());
        }

        public ApplicationFile getFile(Path path) {
            return new MockApplicationPackage.MockApplicationFile(path, root());
        }

        /**
         * Returns UTF-8 readers for the files directly inside {@code path} whose names end with {@code suffix}.
         * Returns an empty list if {@code path} cannot be listed. The returned readers are open;
         * presumably the caller closes them.
         */
        public List<NamedReader> getFiles(Path path, String suffix) {
            File[] files = getFileReference(path).listFiles();
            if (files == null) {
                return List.of();
            }
            List<NamedReader> readers = new ArrayList<>();
            for (File file : files) {
                if ( ! file.getName().endsWith(suffix)) {
                    continue;
                }
                try {
                    // Explicit charset: the no-charset FileReader constructor depends on the platform default.
                    readers.add(new NamedReader(file.getName(), new FileReader(file, StandardCharsets.UTF_8)));
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }
            return readers;
        }
    }

    /** Deletes the generated-models directory that compiling the rank profiles writes under the application dir. */
    @AfterEach
    public void removeGeneratedModelFiles() {
        IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
    }

    @Test
    void testOnnxReferenceWithConstantFeature() {
        fixtureWith("constant(mytensor)", "onnx_vespa('mnist_softmax.onnx')",
                    VESPA_EXPRESSION_CONSTANTS + MYTENSOR_CONSTANT, null)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testOnnxReferenceWithQueryFeature() {
        StoringApplicationPackage application =
                new StoringApplicationPackage(applicationDir, QUERY_PROFILE,
                                              queryProfileType("tensor&lt;float&gt;(d0[1],d1[784])"));
        fixtureWith("query(mytensor)", "onnx_vespa('mnist_softmax.onnx')", VESPA_EXPRESSION_CONSTANTS, null,
                    "Placeholder", application)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testOnnxReferenceWithDocumentFeature() {
        fixtureWith("attribute(mytensor)", "onnx_vespa('mnist_softmax.onnx')", VESPA_EXPRESSION_CONSTANTS,
                    MYTENSOR_FIELD, "Placeholder", new StoringApplicationPackage(applicationDir))
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testOnnxReferenceWithFeatureCombination() {
        StoringApplicationPackage application =
                new StoringApplicationPackage(applicationDir, QUERY_PROFILE,
                                              queryProfileType("tensor&lt;float&gt;(d0[1],d1[784],d2[10])"));
        fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
                    "onnx_vespa('mnist_softmax.onnx')", VESPA_EXPRESSION_CONSTANTS + MYTENSOR_CONSTANT,
                    MYTENSOR_FIELD, "Placeholder", application)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testNestedOnnxReference() {
        fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "5 + sum(onnx_vespa('mnist_softmax.onnx'))",
                    VESPA_EXPRESSION_CONSTANTS)
                .assertFirstPhaseExpression("5 + reduce(" + VESPA_EXPRESSION + ", sum)", "my_profile");
    }

    @Test
    void testOnnxReferenceWithSpecifiedOutput() {
        fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'layer_add')",
                    VESPA_EXPRESSION_CONSTANTS)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testOnnxReferenceWithSpecifiedOutputAndSignature() {
        fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')",
                    VESPA_EXPRESSION_CONSTANTS)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
    }

    @Test
    void testOnnxReferenceMissingFunction() {
        String rankProfile =
                "  rank-profile my_profile {\n" +
                "    first-phase {\n" +
                "      expression: onnx_vespa('mnist_softmax.onnx')" +
                "    }\n" +
                "  }";
        IllegalArgumentException e = Assertions.assertThrows(IllegalArgumentException.class, () -> {
            RankProfileSearchFixture fixture = new RankProfileSearchFixture(new StoringApplicationPackage(applicationDir),
                                                                            new QueryProfileRegistry(),
                                                                            rankProfile);
            fixture.compileRankProfile("my_profile", applicationDir.append("models"));
            fixture.assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
        });
        Assertions.assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from onnx_vespa(\"mnist_softmax.onnx\"): Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is not present in rank profile 'my_profile'",
                                Exceptions.toMessageString(e));
    }

    @Test
    void testOnnxReferenceWithWrongFunctionType() {
        IllegalArgumentException e = Assertions.assertThrows(IllegalArgumentException.class, () ->
                fixtureWith("tensor(d0[1],d5[10])(0.0)", "onnx_vespa('mnist_softmax.onnx')")
                        .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile"));
        Assertions.assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from onnx_vespa(\"mnist_softmax.onnx\"): Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), but this function returns tensor(d0[1],d5[10])",
                                Exceptions.toMessageString(e));
    }

    @Test
    void testOnnxReferenceSpecifyingNonExistingOutput() {
        IllegalArgumentException e = Assertions.assertThrows(IllegalArgumentException.class, () ->
                fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx', 'y')")
                        .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile"));
        Assertions.assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from onnx_vespa(\"mnist_softmax.onnx\",\"y\"): No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
                                Exceptions.toMessageString(e));
    }

    @Test
    void testImportingFromStoredExpressions() throws IOException {
        fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", "onnx_vespa(\"mnist_softmax.onnx\")", VESPA_EXPRESSION_CONSTANTS)
                .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");

        // The expressions are now stored: compile again from a copy holding only the generated model files.
        Path storedDir = storedApplicationDir();
        try {
            copyGeneratedModelsTo(storedDir);
            String storedConstants =
                    "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n" +
                    "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor<float>(d0[2],d1[784]) }\n";
            fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx_vespa('mnist_softmax.onnx')", storedConstants,
                        null, "Placeholder", new StoringApplicationPackage(storedDir))
                    .assertFirstPhaseExpression(VESPA_EXPRESSION, "my_profile");
        } finally {
            IOUtils.recursiveDeleteDir(storedDir.toFile());
        }
    }

    @Test
    void testImportingFromStoredExpressionsWithFunctionOverridingConstantAndInheritance() throws IOException {
        assertFunctionOverridesConstant(uncompiledFixtureWith(FUNCTION_OVERRIDING_CONSTANT_PROFILES,
                                                              new StoringApplicationPackage(applicationDir),
                                                              FUNCTION_OVERRIDING_CONSTANT_CONSTANTS));

        Path storedDir = storedApplicationDir();
        try {
            copyGeneratedModelsTo(storedDir);
            assertFunctionOverridesConstant(uncompiledFixtureWith(FUNCTION_OVERRIDING_CONSTANT_PROFILES,
                                                                  new StoringApplicationPackage(storedDir),
                                                                  FUNCTION_OVERRIDING_CONSTANT_CONSTANTS));
        } finally {
            IOUtils.recursiveDeleteDir(storedDir.toFile());
        }
    }

    @Test
    void testFunctionGeneration() {
        RankProfileSearchFixture fixture = uncompiledFixtureWith(SMALL_CONSTANTS_PROFILE,
                                                                 new StoringApplicationPackage(applicationDir));
        fixture.compileRankProfile("my_profile", applicationDir.append("models"));
        fixture.assertFirstPhaseExpression(SMALL_CONSTANTS_EXPRESSION, "my_profile");
        fixture.assertFunction(SMALL_CONSTANTS_FUNCTION, SMALL_CONSTANTS_FUNCTION_NAME, "my_profile");
    }

    @Test
    void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException {
        assertSmallConstantsAndFunctions(uncompiledFixtureWith(SMALL_CONSTANTS_PROFILE_WITH_CHILD,
                                                               new StoringApplicationPackage(applicationDir)));

        Path storedDir = storedApplicationDir();
        try {
            copyGeneratedModelsTo(storedDir);
            // Every assertion, including the small-constant check, must run against the fixture built from the copy.
            assertSmallConstantsAndFunctions(uncompiledFixtureWith(SMALL_CONSTANTS_PROFILE_WITH_CHILD,
                                                                   new StoringApplicationPackage(storedDir)));
        } finally {
            IOUtils.recursiveDeleteDir(storedDir.toFile());
        }
    }

    /**
     * Compiles my_profile and my_profile_child from {@link #FUNCTION_OVERRIDING_CONSTANT_PROFILES}. Asserts that
     * both resolve mnist_softmax_layer_Variable to the function. Also asserts that no schema constant was added
     * under that name.
     */
    private void assertFunctionOverridesConstant(RankProfileSearchFixture fixture) {
        fixture.compileRankProfile("my_profile", applicationDir.append("models"));
        fixture.compileRankProfile("my_profile_child", applicationDir.append("models"));
        fixture.assertFirstPhaseExpression(FUNCTION_OVERRIDING_CONSTANT_EXPRESSION, "my_profile");
        fixture.assertFirstPhaseExpression(FUNCTION_OVERRIDING_CONSTANT_EXPRESSION, "my_profile_child");
        // Look up by the model's actual constant name, keyed the same way as assertSmallConstant.
        Assertions.assertNull(fixture.search().constants().get(FeatureNames.asConstantFeature("mnist_softmax_layer_Variable")),
                              "Constant overridden by function is not added");
    }

    /**
     * Compiles my_profile and my_profile_child from {@link #SMALL_CONSTANTS_PROFILE_WITH_CHILD}. Asserts their
     * first-phase expressions, the small epsilon constant, and the generated function.
     */
    private void assertSmallConstantsAndFunctions(RankProfileSearchFixture fixture) {
        fixture.compileRankProfile("my_profile", applicationDir.append("models"));
        fixture.compileRankProfile("my_profile_child", applicationDir.append("models"));
        fixture.assertFirstPhaseExpression(SMALL_CONSTANTS_EXPRESSION, "my_profile");
        fixture.assertFirstPhaseExpression(SMALL_CONSTANTS_EXPRESSION, "my_profile_child");
        assertSmallConstant("small_constants_and_functions_epsilon", TensorType.fromSpec("tensor()"), fixture);
        fixture.assertFunction(SMALL_CONSTANTS_FUNCTION, SMALL_CONSTANTS_FUNCTION_NAME, "my_profile");
        fixture.assertFunction(SMALL_CONSTANTS_FUNCTION, SMALL_CONSTANTS_FUNCTION_NAME, "my_profile_child");
    }

    /** Returns the sibling directory that the stored-expression tests copy generated model files into. */
    private Path storedApplicationDir() {
        return applicationDir.getParentPath().append("copy");
    }

    /** Creates {@code target} and copies the generated-models directory of the application dir into it. */
    private void copyGeneratedModelsTo(Path target) throws IOException {
        target.toFile().mkdirs();
        IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
                              target.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
    }

    /** Builds query profile type XML declaring query(mytensor) with the given (XML-escaped) tensor type. */
    private static String queryProfileType(String escapedTensorType) {
        return "<query-profile-type id='root'>  <field name='query(mytensor)' type='" + escapedTensorType + "'/></query-profile-type>";
    }

    /** Asserts that compiled my_profile contains a constant named {@code name} with the given type. */
    private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture fixture) {
        RankProfile.Constant constant = (RankProfile.Constant) fixture.compiledRankProfile("my_profile").constants().get(FeatureNames.asConstantFeature(name));
        Assertions.assertNotNull(constant);
        Assertions.assertEquals(type, constant.type());
    }

    private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
        return fixtureWith(placeholderExpression, firstPhaseExpression, null);
    }

    private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
                                                 String constants) {
        return fixtureWith(placeholderExpression, firstPhaseExpression, constants, null, "Placeholder",
                           new StoringApplicationPackage(applicationDir));
    }

    private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
                                                 String constants, String field) {
        return fixtureWith(placeholderExpression, firstPhaseExpression, constants, field, "Placeholder",
                           new StoringApplicationPackage(applicationDir));
    }

    private RankProfileSearchFixture uncompiledFixtureWith(String rankProfiles, StoringApplicationPackage application) {
        return uncompiledFixtureWith(rankProfiles, application, null);
    }

    /** Creates a fixture for the given rank profiles and schema constants without compiling any profile. */
    private RankProfileSearchFixture uncompiledFixtureWith(String rankProfiles, StoringApplicationPackage application,
                                                           String constants) {
        try {
            return new RankProfileSearchFixture(application, application.getQueryProfiles(), rankProfiles, constants, null);
        } catch (ParseException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Creates a fixture with a single rank profile "my_profile" and compiles it. The profile defines function
     * {@code functionName} with body {@code placeholderExpression} and uses {@code firstPhaseExpression} as its
     * first phase. A ParseException is rethrown as IllegalArgumentException.
     */
    private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
                                                 String constants, String field, String functionName,
                                                 StoringApplicationPackage application) {
        String rankProfile =
                "  rank-profile my_profile {\n" +
                "    function " + functionName + "() {\n" +
                "      expression: " + placeholderExpression +
                "    }\n" +
                "    first-phase {\n" +
                "      expression: " + firstPhaseExpression +
                "    }\n" +
                "  }";
        try {
            RankProfileSearchFixture fixture = new RankProfileSearchFixture(application, application.getQueryProfiles(),
                                                                            rankProfile, constants, field);
            fixture.compileRankProfile("my_profile", applicationDir.append("models"));
            return fixture;
        } catch (ParseException e) {
            throw new IllegalArgumentException(e);
        }
    }
}
