package com.facebook.presto.hive;

import com.facebook.presto.Session;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
import io.airlift.tpch.TpchTable;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/hive/TestHivePushdownFilterQueries.class */
/**
 * Query tests that run against a Hive query runner created with
 * {@code hive.pushdown-filter-enabled=true}.
 *
 * <p>Most tests read {@code lineitem_ex}, a table derived from {@code lineitem} in
 * {@link #createQueryRunner()} that adds nullable boolean, tinyint, timestamp and
 * (nested) array columns. The expected results for those tests are produced by
 * prepending {@link #WITH_LINEITEM_EX} to a translated copy of the query (see
 * {@link #toH2(String)}), so the expected-result query sees a CTE named
 * {@code lineitem_ex} with the same columns.
 */
public class TestHivePushdownFilterQueries extends AbstractTestQueryFramework {
    /**
     * Matches a lower-case column name followed by one or more numeric subscripts,
     * e.g. {@code keys[2]} or {@code nested_keys[1][2]}. Group 1 is the column name,
     * group 2 is the full run of subscripts.
     *
     * <p>The name group deliberately excludes {@code '+'}: with it, an expression such
     * as {@code a+keys[1]} would be captured whole and rewritten to
     * {@code array_get(a+keys, 1)}.
     */
    private static final Pattern ARRAY_SUBSCRIPT_PATTERN = Pattern.compile("([a-z_]+)((\\[[0-9]+\\])+)");

    /**
     * CTE prefix for expected-result queries. It defines {@code lineitem_ex} over
     * {@code lineitem} using the same null-injection rules (modulo conditions) as the
     * CREATE TABLE statement in {@link #CREATE_LINEITEM_EX}; array columns are written
     * as parenthesized tuples here instead of {@code ARRAY[...]}.
     */
    private static final String WITH_LINEITEM_EX = "WITH lineitem_ex AS (\n"
            + "SELECT linenumber, orderkey, partkey, suppkey, \n"
            + "   CASE WHEN linenumber % 5 = 0 THEN null ELSE shipmode = 'AIR' END AS ship_by_air, \n"
            + "   CASE WHEN linenumber % 7 = 0 THEN null ELSE returnflag = 'R' END AS is_returned, \n"
            + "   CASE WHEN linenumber % 4 = 0 THEN null ELSE CAST(day(shipdate) AS TINYINT) END AS ship_day, "
            + "   CASE WHEN linenumber % 6 = 0 THEN null ELSE CAST(month(shipdate) AS TINYINT) END AS ship_month, "
            + "   CASE WHEN linenumber % 3 = 0 THEN null ELSE CAST(shipdate AS TIMESTAMP) END AS ship_timestamp, \n"
            + "   CASE WHEN orderkey % 3 = 0 THEN null ELSE CAST(commitdate AS TIMESTAMP) END AS commit_timestamp, \n"
            + "   CASE WHEN orderkey % 11 = 0 THEN null ELSE (orderkey, partkey, suppkey) END AS keys, \n"
            + "   CASE WHEN orderkey % 13 = 0 THEN null ELSE ((orderkey, partkey), (suppkey,), CASE WHEN orderkey % 17 = 0 THEN null ELSE (orderkey, partkey) END) END AS nested_keys, \n"
            + "   CASE WHEN orderkey % 17 = 0 THEN null ELSE (shipmode = 'AIR', returnflag = 'R') END as flags\n"
            + "FROM lineitem)\n";

    /** Statement that materializes {@code lineitem_ex} in the Hive catalog. */
    private static final String CREATE_LINEITEM_EX = "CREATE TABLE lineitem_ex (linenumber, orderkey, partkey, suppkey, ship_by_air, is_returned, ship_day, ship_month, ship_timestamp, commit_timestamp, keys, nested_keys, flags) AS "
            + "SELECT linenumber, orderkey, partkey, suppkey, "
            + "   IF (linenumber % 5 = 0, null, shipmode = 'AIR') AS ship_by_air, "
            + "   IF (linenumber % 7 = 0, null, returnflag = 'R') AS is_returned, "
            + "   IF (linenumber % 4 = 0, null, CAST(day(shipdate) AS TINYINT)) AS ship_day, "
            + "   IF (linenumber % 6 = 0, null, CAST(month(shipdate) AS TINYINT)) AS ship_month, "
            + "   IF (linenumber % 3 = 0, null, CAST(shipdate AS TIMESTAMP)) AS ship_timestamp, "
            + "   IF (orderkey % 3 = 0, null, CAST(commitdate AS TIMESTAMP)) AS commit_timestamp, "
            + "   IF (orderkey % 11 = 0, null, ARRAY[orderkey, partkey, suppkey]), "
            + "   IF (orderkey % 13 = 0, null, ARRAY[ARRAY[orderkey, partkey], ARRAY[suppkey], IF (orderkey % 17 = 0, null, ARRAY[orderkey, partkey])]), "
            + "   IF (orderkey % 17 = 0, null, ARRAY[shipmode = 'AIR', returnflag = 'R']) "
            + "FROM lineitem";

    protected TestHivePushdownFilterQueries() {
        super(TestHivePushdownFilterQueries::createQueryRunner);
    }

    /**
     * Builds the query runner used by every test: all TPC-H tables, the
     * {@code sql-standard} security setting and {@code hive.pushdown-filter-enabled=true}.
     * The {@code lineitem_ex} table is then created using a session in which
     * {@code pushdown_filter_enabled} is switched off (see {@link #noPushdownFilter(Session)}).
     *
     * @return the configured runner with {@code lineitem_ex} already populated
     * @throws Exception if the runner cannot be created
     */
    private static QueryRunner createQueryRunner() throws Exception {
        DistributedQueryRunner queryRunner = HiveQueryRunner.createQueryRunner(
                TpchTable.getTables(),
                ImmutableMap.of(),
                "sql-standard",
                ImmutableMap.of("hive.pushdown-filter-enabled", "true"),
                Optional.empty());
        queryRunner.execute(noPushdownFilter(queryRunner.getDefaultSession()), CREATE_LINEITEM_EX);
        return queryRunner;
    }

    /** Projections and filters over the nullable boolean columns. */
    @Test
    public void testBooleans() {
        assertQueryUsingH2Cte("SELECT is_returned FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT is_returned FROM lineitem_ex WHERE is_returned = true");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE is_returned is not null");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE is_returned = false");
        assertQueryUsingH2Cte("SELECT ship_by_air, is_returned FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT ship_by_air, is_returned FROM lineitem_ex WHERE ship_by_air = true");
        assertQueryUsingH2Cte("SELECT ship_by_air, is_returned FROM lineitem_ex WHERE ship_by_air = true AND is_returned = false");
        assertQueryUsingH2Cte("SELECT COUNT(*) FROM lineitem_ex WHERE ship_by_air is null");
        assertQueryUsingH2Cte("SELECT COUNT(*) FROM lineitem_ex WHERE ship_by_air is not null AND is_returned = true");
    }

    /** Projections and filters over the nullable TINYINT columns. */
    @Test
    public void testBytes() {
        assertQueryUsingH2Cte("SELECT ship_day FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT ship_day FROM lineitem_ex WHERE ship_day < 15");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE ship_day > 15");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE ship_day is null");
        assertQueryUsingH2Cte("SELECT ship_day, ship_month FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT ship_day, ship_month FROM lineitem_ex WHERE ship_month = 1");
        assertQueryUsingH2Cte("SELECT ship_day, ship_month FROM lineitem_ex WHERE ship_day = 1 AND ship_month = 1");
        assertQueryUsingH2Cte("SELECT COUNT(*) FROM lineitem_ex WHERE ship_month is null");
        assertQueryUsingH2Cte("SELECT COUNT(*) FROM lineitem_ex WHERE ship_day is not null AND ship_month = 1");
        assertQueryUsingH2Cte("SELECT ship_day, ship_month FROM lineitem_ex WHERE ship_day > 15 AND ship_month < 5 AND (ship_day + ship_month) < 20");
    }

    /** Range filters over numeric and date columns of {@code orders} and {@code lineitem_ex}. */
    @Test
    public void testNumeric() {
        assertQuery("SELECT orderkey, custkey, orderdate, shippriority FROM orders");
        assertQuery("SELECT count(*) FROM orders WHERE orderkey BETWEEN 100 AND 1000 AND custkey BETWEEN 500 AND 800");
        assertQuery("SELECT custkey, orderdate, shippriority FROM orders WHERE orderkey BETWEEN 100 AND 1000 AND custkey BETWEEN 500 AND 800");
        assertQuery("SELECT orderkey, orderdate FROM orders WHERE orderdate BETWEEN date '1994-01-01' AND date '1997-03-30'");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE orderkey < 30000 AND ship_by_air = true");
        assertQueryUsingH2Cte("SELECT linenumber, orderkey, ship_by_air, is_returned FROM lineitem_ex WHERE orderkey < 30000 AND ship_by_air = true");
        assertQueryUsingH2Cte("SELECT linenumber, ship_by_air, is_returned FROM lineitem_ex WHERE orderkey < 30000 AND ship_by_air = true");
    }

    /** Projections and filters over the nullable TIMESTAMP columns. */
    @Test
    public void testTimestamps() {
        assertQueryUsingH2Cte("SELECT ship_timestamp FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT ship_timestamp FROM lineitem_ex WHERE ship_timestamp < TIMESTAMP '1993-01-01 01:00:00'");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE ship_timestamp IS NOT NULL");
        assertQueryUsingH2Cte("SELECT count(*) FROM lineitem_ex WHERE ship_timestamp = TIMESTAMP '2012-08-08 01:00:00'");
        assertQueryUsingH2Cte("SELECT commit_timestamp, ship_timestamp FROM lineitem_ex");
        assertQueryUsingH2Cte("SELECT commit_timestamp, ship_timestamp FROM lineitem_ex WHERE ship_timestamp > TIMESTAMP '1993-08-08 01:00:00' AND commit_timestamp < TIMESTAMP '1993-08-08 01:00:00'");
        assertQueryReturnsEmptyResult("SELECT commit_timestamp, ship_timestamp FROM lineitem_ex WHERE year(ship_timestamp) - year(commit_timestamp) > 1");
        assertQueryUsingH2Cte("SELECT commit_timestamp, ship_timestamp, orderkey FROM lineitem_ex WHERE year(commit_timestamp) > 1993 and year(ship_timestamp) > 1993 and year(ship_timestamp) - year(commit_timestamp) = 1");
    }

    /**
     * Null checks, subscript filters and out-of-bounds subscripts over the array
     * columns {@code keys}, {@code nested_keys} and {@code flags}.
     */
    @Test
    public void testArrays() {
        assertQueryUsingH2Cte("SELECT * FROM lineitem_ex");
        assertFilterProject("keys IS NULL", "orderkey, flags");
        assertFilterProject("nested_keys IS NULL", "keys, flags");
        assertFilterProject("flags IS NOT NULL", "keys, orderkey");
        assertFilterProject("nested_keys IS NOT NULL", "keys, flags");
        assertFilterProject("nested_keys[3] IS NULL", "keys, flags");
        assertFilterProject("nested_keys[3] IS NOT NULL", "keys, flags");
        assertQueryUsingH2Cte("SELECT * FROM lineitem_ex WHERE orderkey = 1");
        assertQueryUsingH2Cte("SELECT * FROM lineitem_ex WHERE orderkey % 3 = 1");
        assertFilterProject("keys[2] = 1", "orderkey, flags");
        assertFilterProject("nested_keys[1][2] = 1", "orderkey, flags");
        assertFilterProject("keys[2] % 3 = 1", "orderkey, flags");
        assertFilterProject("nested_keys[1][2] % 3 = 1", "orderkey, flags");
        assertFilterProject("keys[1] < 1000", "orderkey, flags");
        assertFilterProject("nested_keys[1][1] < 1000", "orderkey, flags");
        assertFilterProject("keys[1] < 1000 AND keys[2] % 3 = 1", "orderkey, flags");
        assertFilterProject("nested_keys[1][1] < 1000 AND nested_keys[1][2] % 3 = 1", "orderkey, flags");
        assertFilterProject("keys[1] % 3 = 1 AND (orderkey + keys[2]) % 5 = 1", "orderkey, flags");
        assertFilterProject("nested_keys[1][1] % 3 = 1 AND (orderkey + nested_keys[1][2]) % 5 = 1", "orderkey, flags");
        assertFilterProject("keys[1] < 1000 AND flags[2] = true AND keys[2] % 2 = if(flags[1], 0, 1)", "orderkey, flags");
        assertFilterProject("nested_keys[1][1] < 1000 AND flags[2] = true AND nested_keys[1][2] % 2 = if(flags[1], 0, 1)", "orderkey, flags");
        assertFilterProject("nested_keys IS NOT NULL AND nested_keys[1][1] > 0", "keys");
        assertFilterProject("nested_keys[3] IS NULL AND nested_keys[2][1] > 10", "keys, flags");
        assertFilterProject("nested_keys[3] IS NOT NULL AND nested_keys[1][2] > 10", "keys, flags");
        assertFilterProject("nested_keys IS NOT NULL AND nested_keys[3] IS NOT NULL AND nested_keys[1][1] > 0", "keys");
        assertFilterProject("nested_keys IS NOT NULL AND nested_keys[3] IS NULL AND nested_keys[1][1] > 0", "keys");
        assertFilterProjectFails("keys[5] > 0", "orderkey", "Array subscript out of bounds");
        assertFilterProjectFails("nested_keys[5][1] > 0", "orderkey", "Array subscript out of bounds");
        assertFilterProjectFails("nested_keys[1][5] > 0", "orderkey", "Array subscript out of bounds");
        assertFilterProjectFails("nested_keys[2][5] > 0", "orderkey", "Array subscript out of bounds");
    }

    /**
     * Checks a filter on {@code lineitem_ex} twice: once selecting every column and
     * once selecting only the given columns.
     *
     * @param filter boolean SQL expression placed after {@code WHERE}
     * @param projectedColumns comma-separated select list for the second query
     */
    private void assertFilterProject(String filter, String projectedColumns) {
        assertQueryUsingH2Cte(String.format("SELECT * FROM lineitem_ex WHERE %s", filter));
        assertQueryUsingH2Cte(String.format("SELECT %s FROM lineitem_ex WHERE %s", projectedColumns, filter));
    }

    /**
     * Same two query shapes as {@link #assertFilterProject(String, String)}, but both
     * are expected to fail.
     *
     * @param filter boolean SQL expression placed after {@code WHERE}
     * @param projectedColumns comma-separated select list for the second query
     * @param expectedMessage expected failure message, handed to {@code assertQueryFails}
     */
    private void assertFilterProjectFails(String filter, String projectedColumns, String expectedMessage) {
        assertQueryFails(String.format("SELECT * FROM lineitem_ex WHERE %s", filter), expectedMessage);
        assertQueryFails(String.format("SELECT %s FROM lineitem_ex WHERE %s", projectedColumns, filter), expectedMessage);
    }

    /** Filters that are arbitrary expressions (modulo, {@code if}, casts, subscripts). */
    @Test
    public void testFilterFunctions() {
        assertQuery("SELECT custkey, orderkey, orderdate FROM orders WHERE orderkey % 5 = 0");
        assertQuery("SELECT custkey, orderdate FROM orders WHERE orderkey % 5 = 0");
        assertQuery("SELECT custkey, orderdate FROM orders WHERE orderkey % 5 = 0 AND orderkey > 100");
        assertQuery("SELECT custkey, orderdate FROM orders WHERE orderkey % 5 = 0 AND custkey % 7 = 0");
        assertQuery("SELECT custkey, orderdate FROM orders WHERE (orderkey + custkey) % 5 = 0");
        assertQueryFails("SELECT custkey, orderdate FROM orders WHERE array[1, 2, 3][orderkey % 5 + custkey % 7 + 1] > 0", "Array subscript out of bounds");
        assertQuery("SELECT custkey, orderdate FROM orders WHERE array[1, 2, 3][orderkey % 5 + custkey % 7 + 1] > 0 AND orderkey % 5 = 1 AND custkey % 7 = 0", "SELECT custkey, orderdate FROM orders WHERE orderkey % 5 = 1 AND custkey % 7 = 0");
        assertFilterProject("if(is_returned, linenumber, orderkey) % 5 = 0", "linenumber");
        assertFilterProject("keys[1] % 5 = 0", "orderkey");
        assertFilterProject("nested_keys[1][1] % 5 = 0", "orderkey");
        assertFilterProject("keys[1] % 5 = 0 AND keys[2] > 100", "orderkey");
        assertFilterProject("keys[1] % 5 = 0 AND nested_keys[1][2] > 100", "orderkey");
        assertFilterProject("keys[1] % 5 = 0 AND keys[2] % 7 = 0", "orderkey");
        assertFilterProject("keys[1] % 5 = 0 AND nested_keys[1][2] % 7 = 0", "orderkey");
        assertFilterProject("(cast(keys[1] as integer) + keys[3]) % 5 = 0", "orderkey");
        assertFilterProject("(cast(keys[1] as integer) + nested_keys[1][2]) % 5 = 0", "orderkey");
        assertQueryFails("SELECT orderkey FROM lineitem_ex WHERE keys[5] % 7 = 0", "Array subscript out of bounds");
        assertQueryFails("SELECT orderkey FROM lineitem_ex WHERE nested_keys[1][5] % 7 = 0", "Array subscript out of bounds");
        assertQueryFails("SELECT * FROM lineitem_ex WHERE nested_keys[1][5] > 0", "Array subscript out of bounds");
        assertQueryFails("SELECT orderkey FROM lineitem_ex WHERE nested_keys[1][5] > 0", "Array subscript out of bounds");
        assertQueryFails("SELECT * FROM lineitem_ex WHERE nested_keys[1][5] > 0 AND orderkey % 5 = 0", "Array subscript out of bounds");
        // "orderkey % 5 > 10" can never be true, and this query is asserted to succeed
        // even though it contains the same out-of-bounds subscript as the failing ones above.
        assertFilterProject("nested_keys[1][5] > 0 AND orderkey % 5 > 10", "keys");
    }

    /** Filters on a partition column of a freshly created partitioned table. */
    @Test
    public void testPartitionColumns() {
        assertUpdate("CREATE TABLE test_partition_columns WITH (partitioned_by = ARRAY['p']) AS\nSELECT * FROM (VALUES (1, 'abc'), (2, 'abc')) as t(x, p)", 2L);
        try {
            assertQuery("SELECT * FROM test_partition_columns", "SELECT 1, 'abc' UNION ALL SELECT 2, 'abc'");
            assertQuery("SELECT * FROM test_partition_columns WHERE p = 'abc'", "SELECT 1, 'abc' UNION ALL SELECT 2, 'abc'");
            assertQuery("SELECT * FROM test_partition_columns WHERE p LIKE 'a%'", "SELECT 1, 'abc' UNION ALL SELECT 2, 'abc'");
            assertQuery("SELECT * FROM test_partition_columns WHERE substr(p, x, 1) = 'a'", "SELECT 1, 'abc'");
            assertQueryReturnsEmptyResult("SELECT * FROM test_partition_columns WHERE p = 'xxx'");
        } finally {
            // Drop the scratch table even when an assertion above fails.
            assertUpdate("DROP TABLE test_partition_columns");
        }
    }

    /** Projection of and filter on the hidden {@code "$bucket"} column of a bucketed table. */
    @Test
    public void testBucketColumn() {
        getQueryRunner().execute("CREATE TABLE test_bucket_column WITH (bucketed_by = ARRAY['orderkey'], bucket_count = 11) AS SELECT linenumber, orderkey FROM lineitem");
        try {
            assertQuery("SELECT linenumber, \"$bucket\" FROM test_bucket_column", "SELECT linenumber, orderkey % 11 FROM lineitem");
            assertQuery("SELECT linenumber, \"$bucket\" FROM test_bucket_column WHERE (\"$bucket\" + linenumber) % 2 = 1", "SELECT linenumber, orderkey % 11 FROM lineitem WHERE (orderkey % 11 + linenumber) % 2 = 1");
        } finally {
            // Drop the scratch table even when an assertion above fails.
            assertUpdate("DROP TABLE test_bucket_column");
        }
    }

    /**
     * Projection of and filter on the hidden {@code "$path"} column. Only success is
     * asserted; no expected rows are compared.
     */
    @Test
    public void testPathColumn() {
        Session session = getQueryRunner().getDefaultSession();
        assertQuerySucceeds(session, "SELECT linenumber, \"$path\" FROM lineitem");
        assertQuerySucceeds(session, "SELECT linenumber, \"$path\" FROM lineitem WHERE length(\"$path\") % 2 = linenumber % 2");
    }

    /**
     * Runs {@code query} and compares it with an expected-result query made of
     * {@link #WITH_LINEITEM_EX} followed by the translated form of the same query.
     *
     * @param query query over {@code lineitem_ex}
     */
    private void assertQueryUsingH2Cte(String query) {
        assertQuery(query, WITH_LINEITEM_EX + toH2(query));
    }

    /**
     * Translates a query for use as the expected-result query: array subscripts become
     * {@code array_get(...)} calls and every {@code " if("} becomes {@code " casewhen("}.
     * Only occurrences of {@code if(} preceded by a space are rewritten.
     *
     * @param query query text to translate
     * @return the translated query text
     */
    private static String toH2(String query) {
        return replaceArraySubscripts(query).replaceAll(" if\\(", " casewhen(");
    }

    /**
     * Rewrites every {@code name[i][j]...} occurrence into nested {@code array_get}
     * calls, innermost subscript first; for example {@code nested_keys[1][2]} becomes
     * {@code array_get(array_get(nested_keys, 1), 2)}. Text outside the matches is
     * copied unchanged.
     *
     * @param sql query text to rewrite
     * @return the rewritten query text
     */
    private static String replaceArraySubscripts(String sql) {
        Matcher matcher = ARRAY_SUBSCRIPT_PATTERN.matcher(sql);
        StringBuilder rewritten = new StringBuilder();
        int copiedUpTo = 0;
        while (matcher.find()) {
            String expression = matcher.group(1);
            // Group 2 looks like "[1][2]"; splitting on non-digits leaves just the indices.
            List<String> indices = Splitter.onPattern("[^0-9]").omitEmptyStrings().splitToList(matcher.group(2));
            for (String index : indices) {
                expression = String.format("array_get(%s, %s)", expression, index);
            }
            rewritten.append(sql, copiedUpTo, matcher.start()).append(expression);
            copiedUpTo = matcher.end();
        }
        rewritten.append(sql.substring(copiedUpTo));
        return rewritten.toString();
    }

    /**
     * Returns a copy of {@code session} with the Hive catalog session property
     * {@code pushdown_filter_enabled} set to {@code "false"}.
     *
     * @param session session to copy
     * @return the derived session
     */
    private static Session noPushdownFilter(Session session) {
        return Session.builder(session)
                .setCatalogSessionProperty(HiveQueryRunner.HIVE_CATALOG, "pushdown_filter_enabled", "false")
                .build();
    }
}
