package com.yahoo.schema.processing;

import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.model.application.provider.BaseDeployLogger;
import com.yahoo.config.model.application.provider.MockFileRegistry;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.schema.ApplicationBuilder;
import com.yahoo.schema.RankProfile;
import com.yahoo.schema.RankProfileRegistry;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
import com.yahoo.schema.expressiontransforms.TokenTransformer;
import com.yahoo.schema.parser.ParseException;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import java.util.Map;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:com/yahoo/schema/processing/RankingExpressionWithTransformerTokensTestCase.class */
/**
 * Tests the token ranking expression functions rewritten by {@link TokenTransformer}:
 * {@code tokenInputIds}, {@code customTokenInputIds}, {@code tokenTypeIds} and
 * {@code tokenAttentionMask}.
 */
public class RankingExpressionWithTransformerTokensTestCase {

    @Test
    void testTokenInputIds() throws Exception {
        // Expected: 101, then each input followed by 102, zero-padded to length 12.
        Tensor expected = Tensor.from("tensor(d0[1],d1[12]):[101,1,2,102,3,4,5,102,6,7,102,0]");
        Tensor actual = evaluateExpression("tokenInputIds(12, a, b, c)",
                "tensor(d0[2]):[1,2]", "tensor(d0[3]):[3,4,5]", "tensor(d0[2]):[6,7]");
        Assertions.assertEquals(expected, actual);
    }

    @Test
    void testTokenInputIdsCustomPadTokens() throws Exception {
        // Expected: start token 1, each input followed by 2, zero-padded to length 13.
        Tensor expected = Tensor.from("tensor(d0[1],d1[13]):[1,11,12,2,13,14,15,2,16,17,2,0,0]");
        Tensor actual = evaluateExpression("customTokenInputIds(1, 2, 13, a, b, c)",
                "tensor(d0[2]):[11,12]", "tensor(d0[3]):[13,14,15]", "tensor(d0[2]):[16,17]");
        Assertions.assertEquals(expected, actual);
    }

    @Test
    void testTokenTypeIds() throws Exception {
        // Expected: 0 for the first segment (4 positions), 1 for the second (4 positions), 0 padding.
        Tensor expected = Tensor.from("tensor(d0[1],d1[10]):[0,0,0,0,1,1,1,1,0,0]");
        Tensor actual = evaluateExpression("tokenTypeIds(10, a, b)",
                "tensor(d0[2]):[1,2]", "tensor(d0[3]):[3,4,5]");
        Assertions.assertEquals(expected, actual);
    }

    @Test
    void testAttentionMask() throws Exception {
        // Expected: 1 for the 8 occupied positions, 0 for the padding.
        Tensor expected = Tensor.from("tensor(d0[1],d1[10]):[1,1,1,1,1,1,1,1,0,0]");
        Tensor actual = evaluateExpression("tokenAttentionMask(10, a, b)",
                "tensor(d0[2]):[1,2]", "tensor(d0[3]):[3,4,5]");
        Assertions.assertEquals(expected, actual);
    }

    /**
     * Rewrites {@code expression} with {@link TokenTransformer} and evaluates the result.
     *
     * @param expression the ranking expression to transform and evaluate
     * @param inputs tensor literals bound, in order, to the variables {@code a}, {@code b},
     *               {@code c}, ...; {@code null} entries are left unbound
     * @return the tensor the transformed expression evaluates to
     * @throws Exception if parsing, transforming or evaluating fails
     */
    private Tensor evaluateExpression(String expression, String... inputs) throws Exception {
        MapContext context = new MapContext();
        for (int i = 0; i < inputs.length; i++) {
            if (inputs[i] != null) {
                String variableName = String.valueOf((char) ('a' + i));
                context.put(variableName, new TensorValue(Tensor.from(inputs[i])));
            }
        }

        RankProfileTransformContext transformContext = createTransformContext();
        RankingExpression transformed =
                new TokenTransformer().transform(new RankingExpression(expression), transformContext);

        // Every function present in the rank profile after the transform (presumably helpers
        // registered by TokenTransformer) is evaluated as a scalar and bound by name, so the
        // transformed expression can reference it.
        for (Map.Entry<String, RankProfile.RankingExpressionFunction> function :
                transformContext.rankProfile().getFunctions().entrySet()) {
            double value = function.getValue().function().getBody().evaluate(context).asDouble();
            context.put(function.getKey(), value);
        }
        return transformed.evaluate(context).asTensor();
    }

    /**
     * Builds a minimal application with schema {@code test} and an empty rank profile
     * {@code my_profile}, and wraps that profile in a transform context.
     *
     * @return a transform context for {@code my_profile} with no feature types, models,
     *         constants or inline functions
     * @throws ParseException if the embedded schema fails to parse
     */
    private RankProfileTransformContext createTransformContext() throws ParseException {
        MockApplicationPackage applicationPackage = MockApplicationPackage.createEmpty();
        RankProfileRegistry rankProfiles = new RankProfileRegistry();
        QueryProfileRegistry queryProfiles = applicationPackage.getQueryProfiles();
        ApplicationBuilder builder = new ApplicationBuilder(applicationPackage, new MockFileRegistry(),
                new BaseDeployLogger(), new TestProperties(), rankProfiles, queryProfiles);
        builder.addSchema("search test {\n  document test {}\n  rank-profile my_profile inherits default {}\n}");
        builder.build(true);
        RankProfile profile = rankProfiles.get(builder.getSchema(), "my_profile");
        return new RankProfileTransformContext(profile, queryProfiles,
                Map.of(), (ImportedMlModels) null, Map.of(), Map.of());
    }
}
