package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.Access;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.index.PageRecordSet;
import com.facebook.presto.operator.project.CursorProcessor;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.gen.CommonSubExpressionRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.testing.TestingConnectorSession;
import com.facebook.presto.util.CompilerUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link CursorProcessorCompiler} with common sub-expression (CSE) optimization.
 *
 * <p>The fixtures are built over BIGINT input fields (field 0 = x, field 1 = y, field 2 = z) and
 * share the sub-expression {@code x + y}, which the CSE rewriter is expected to extract.
 */
public class TestCursorProcessorCompiler {
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager();

    /** {@code x + y}. */
    private static final CallExpression ADD_X_Y = bigintOperator(
            OperatorType.ADD,
            BigintType.BIGINT,
            Expressions.field(0, BigintType.BIGINT),
            Expressions.field(1, BigintType.BIGINT));

    /** {@code x + y > 2}. */
    private static final CallExpression ADD_X_Y_GREATER_THAN_2 = bigintOperator(
            OperatorType.GREATER_THAN, BooleanType.BOOLEAN, ADD_X_Y, Expressions.constant(2L, BigintType.BIGINT));

    /** {@code x + y < 10}. */
    private static final CallExpression ADD_X_Y_LESS_THAN_10 = bigintOperator(
            OperatorType.LESS_THAN, BooleanType.BOOLEAN, ADD_X_Y, Expressions.constant(10L, BigintType.BIGINT));

    /**
     * {@code (x + y) + z}. The inner {@code x + y} is built as a separate node that is structurally
     * equal to {@link #ADD_X_Y}, rather than reusing that instance.
     */
    private static final CallExpression ADD_X_Y_Z = bigintOperator(
            OperatorType.ADD,
            BigintType.BIGINT,
            bigintOperator(
                    OperatorType.ADD,
                    BigintType.BIGINT,
                    Expressions.field(0, BigintType.BIGINT),
                    Expressions.field(1, BigintType.BIGINT)),
            Expressions.field(2, BigintType.BIGINT));

    /**
     * Builds a call to {@code operator}, resolved for {@code (BIGINT, BIGINT)} argument types.
     *
     * @param operator the operator to resolve
     * @param returnType the declared return type of the call
     * @param left first argument
     * @param right second argument
     * @return the call expression
     */
    private static CallExpression bigintOperator(OperatorType operator, Type returnType, RowExpression left, RowExpression right) {
        return Expressions.call(
                operator.name(),
                FUNCTION_MANAGER.resolveOperator(operator, TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT)),
                returnType,
                left,
                right);
    }

    /**
     * Verifies that {@code x + y} is extracted as the single CSE and that rewriting both the
     * projection {@code (x + y) + z} and the filter {@code AND(x + y > 2)} substitutes the CSE
     * variable for it.
     */
    @Test
    public void testRewriteRowExpressionWithCSE() {
        CursorProcessorCompiler cseCursorCompiler = new CursorProcessorCompiler(METADATA, true, Collections.emptyMap());
        ClassDefinition cursorProcessorClassDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(CursorProcessor.class.getSimpleName()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(CursorProcessor.class));

        // AND produces a boolean, so the special form is typed BOOLEAN.
        RowExpression filter = new SpecialFormExpression(SpecialFormExpression.Form.AND, BooleanType.BOOLEAN, ADD_X_Y_GREATER_THAN_2);
        List<RowExpression> projections = ImmutableList.of(ADD_X_Y_Z);
        List<RowExpression> allExpressions = ImmutableList.<RowExpression>builder()
                .addAll(projections)
                .add(filter)
                .build();

        Map<Integer, Map<RowExpression, VariableReferenceExpression>> cseByLevel =
                CommonSubExpressionRewriter.collectCSEByLevel(allExpressions);
        Map<VariableReferenceExpression, CommonSubExpressionRewriter.CommonSubExpressionFields> cseFields =
                CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields(cursorProcessorClassDefinition, cseByLevel);
        // Flatten the per-level maps into one expression -> CSE variable lookup.
        Map<RowExpression, VariableReferenceExpression> commonSubExpressions = cseByLevel.values().stream()
                .flatMap(level -> level.entrySet().stream())
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

        // Only x + y is expected to be declared as a CSE field (TestNG order: actual, expected).
        Assert.assertEquals(cseFields.size(), 1);
        VariableReferenceExpression cseVariable = cseFields.keySet().iterator().next();

        RowExpression rewrittenFilter = cseCursorCompiler.rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0);
        RowExpression rewrittenProjection = cseCursorCompiler.rewriteRowExpressionsWithCSE(projections, commonSubExpressions).get(0);

        // (x + y) + z now takes the CSE variable as an argument.
        Assert.assertTrue(((CallExpression) rewrittenProjection).getArguments().contains(cseVariable));
        // Inside AND(...), the comparison x + y > 2 now takes the CSE variable as an argument.
        CallExpression rewrittenComparison = (CallExpression) ((SpecialFormExpression) rewrittenFilter).getArguments().get(0);
        Assert.assertTrue(rewrittenComparison.getArguments().contains(cseVariable));
    }

    /**
     * Compiles the same filter and projections into cursor processors with and without CSE, runs
     * both over the same input, and checks that they produce identical, non-empty pages.
     */
    @Test
    public void testCompilerWithCSE() {
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, new PageFunctionCompiler(METADATA, 0));
        // AND produces a boolean, so the special form is typed BOOLEAN.
        RowExpression filter = new SpecialFormExpression(
                SpecialFormExpression.Form.AND, BooleanType.BOOLEAN, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10);
        List<? extends RowExpression> projections = createIfProjectionList(5);

        Supplier<CursorProcessor> cseProcessorSupplier = expressionCompiler.compileCursorProcessor(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", true);
        Supplier<CursorProcessor> noCseProcessorSupplier = expressionCompiler.compileCursorProcessor(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", false);

        // Two identical BIGINT channels (so x = y), each holding 0..9.
        Page input = createLongBlockPage(2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        List<Type> inputTypes = ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT);
        PageRecordSet recordSet = new PageRecordSet(inputTypes, input);
        PageBuilder pageBuilder = new PageBuilder(projections.stream()
                .map(RowExpression::getType)
                .collect(Collectors.toList()));

        cseProcessorSupplier.get().process(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        Page pageFromCse = pageBuilder.build();
        pageBuilder.reset();

        noCseProcessorSupplier.get().process(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        Page pageFromNoCse = pageBuilder.build();

        // With x = y, 2 < x + y < 10 keeps x in {2, 3, 4}; this guards against both outputs being trivially empty.
        Assert.assertEquals(pageFromCse.getPositionCount(), 3);
        checkPageEqual(pageFromCse, pageFromNoCse);
    }

    /**
     * Creates a page of {@code channelCount} BIGINT blocks, each containing {@code values} in order.
     */
    private static Page createLongBlockPage(int channelCount, long... values) {
        Block[] blocks = new Block[channelCount];
        for (int channel = 0; channel < channelCount; channel++) {
            BlockBuilder builder = BigintType.BIGINT.createFixedSizeBlockBuilder(values.length);
            for (long value : values) {
                BigintType.BIGINT.writeLong(builder, value);
            }
            blocks[channel] = builder.build();
        }
        return new Page(blocks);
    }

    /**
     * Creates {@code count} projections of the form {@code IF(x + y > 8, i, i + 1)} for
     * {@code i} in {@code [0, count)}. All of them share the sub-expression {@code x + y}.
     */
    private List<? extends RowExpression> createIfProjectionList(int count) {
        // The condition is identical for every projection, so resolve it once.
        CallExpression condition = bigintOperator(
                OperatorType.GREATER_THAN, BooleanType.BOOLEAN, ADD_X_Y, Expressions.constant(8L, BigintType.BIGINT));
        return IntStream.range(0, count)
                .mapToObj(i -> new SpecialFormExpression(
                        SpecialFormExpression.Form.IF,
                        BigintType.BIGINT,
                        condition,
                        Expressions.constant((long) i, BigintType.BIGINT),
                        Expressions.constant((long) i + 1, BigintType.BIGINT)))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Asserts that two BIGINT blocks have the same length, null positions, and values.
     */
    private static void checkBlockEqual(Block actual, Block expected) {
        Assert.assertEquals(actual.getPositionCount(), expected.getPositionCount());
        for (int position = 0; position < actual.getPositionCount(); position++) {
            Assert.assertEquals(actual.isNull(position), expected.isNull(position), "null mismatch at position " + position);
            if (!actual.isNull(position)) {
                Assert.assertEquals(actual.getLong(position), expected.getLong(position), "value mismatch at position " + position);
            }
        }
    }

    /**
     * Asserts that two pages have the same shape and that every channel holds equal BIGINT data.
     */
    private static void checkPageEqual(Page actual, Page expected) {
        Assert.assertEquals(actual.getPositionCount(), expected.getPositionCount());
        Assert.assertEquals(actual.getChannelCount(), expected.getChannelCount());
        // Iterate channels (blocks), not positions: getBlock takes a channel index.
        for (int channel = 0; channel < actual.getChannelCount(); channel++) {
            checkBlockEqual(actual.getBlock(channel), expected.getBlock(channel));
        }
    }
}
