package org.apache.iceberg.spark.source;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.mapping.MappingUtil;
import org.apache.iceberg.mapping.NameMappingParser;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.iceberg.types.Conversions;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link SparkTableUtil} that import Parquet tables registered in the Spark session
 * catalog into Iceberg tables created with {@link HadoopTables}.
 *
 * <p>Each test writes a Parquet-backed Spark table named {@code parquet_table}, imports it into a
 * fresh Iceberg table under a temporary folder, and then checks the rows read back and/or the
 * per-file column bounds exposed through the {@code #files} metadata table.
 */
public class TestSparkTableUtilWithInMemoryCatalog {
    private static final HadoopTables TABLES = new HadoopTables(new Configuration());
    private static final Schema SCHEMA = new Schema(
        Types.NestedField.optional(1, "id", Types.IntegerType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()));
    private static final PartitionSpec SPEC = PartitionSpec.builderFor(SCHEMA).identity("data").build();

    /** Name of the Parquet-backed Spark source table (also used as its data folder name). */
    private static final String SOURCE_TABLE = "parquet_table";

    private static SparkSession spark;

    @Rule
    public TemporaryFolder temp = new TemporaryFolder();

    /** URI of a fresh temporary folder that holds the Iceberg table for the current test. */
    private String tableLocation;

    @BeforeClass
    public static void startSpark() {
        spark = SparkSession.builder().master("local[2]").getOrCreate();
    }

    @AfterClass
    public static void stopSpark() {
        SparkSession current = spark;
        spark = null;
        // startSpark may have failed; don't add an NPE on top of that failure
        if (current != null) {
            current.stop();
        }
    }

    @Before
    public void setupTableLocation() throws Exception {
        this.tableLocation = temp.newFolder().toURI().toString();
    }

    @Test
    public void testImportUnpartitionedTable() throws IOException {
        HashMap<String, String> props = Maps.newHashMap();
        props.put("write.metadata.metrics.default", "none");
        props.put("write.metadata.metrics.column.data", "full");
        Table table = TABLES.create(SCHEMA, PartitionSpec.unpartitioned(), props, tableLocation);

        List<SimpleRecord> records = sampleRecords();
        try {
            spark.createDataFrame(records, SimpleRecord.class)
                .coalesce(1)
                .select("id", "data")
                .write()
                .format("parquet")
                .mode("append")
                .option("path", temp.newFolder(SOURCE_TABLE).toURI().toString())
                .saveAsTable(SOURCE_TABLE);

            SparkTableUtil.importSparkTable(
                spark, new TableIdentifier(SOURCE_TABLE), table, temp.newFolder("staging-dir").toString());

            Assert.assertEquals("Result rows should match", records, readIcebergRecords());

            // Default metrics mode is "none"; only "data" is overridden to "full".
            Dataset<Row> files = readFilesTable();
            checkFieldMetrics(files, table.schema().findField("id"), true);
            checkFieldMetrics(files, table.schema().findField("data"), false);
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportPartitionedTable() throws IOException {
        HashMap<String, String> props = Maps.newHashMap();
        props.put("write.metadata.metrics.default", "none");
        props.put("write.metadata.metrics.column.data", "full");
        Table table = TABLES.create(SCHEMA, SPEC, props, tableLocation);

        List<SimpleRecord> records = sampleRecords();
        try {
            spark.createDataFrame(records, SimpleRecord.class)
                .select("id", "data")
                .write()
                .format("parquet")
                .mode("append")
                .option("path", temp.newFolder(SOURCE_TABLE).toURI().toString())
                .partitionBy("data")
                .saveAsTable(SOURCE_TABLE);

            Assert.assertEquals("Should have 3 partitions",
                3, SparkTableUtil.getPartitions(spark, SOURCE_TABLE).size());
            Assert.assertEquals("Should have 1 partition where data = 'a'",
                1, SparkTableUtil.getPartitionsByFilter(spark, SOURCE_TABLE, "data = 'a'").size());

            SparkTableUtil.importSparkTable(
                spark, new TableIdentifier(SOURCE_TABLE), table, temp.newFolder("staging-dir").toString());

            Assert.assertEquals("Result rows should match", records, readIcebergRecords());

            // "data" is a partition column, so it is not stored in the Parquet files and
            // gets no file-level bounds even though its metrics mode is "full".
            Dataset<Row> files = readFilesTable();
            checkFieldMetrics(files, table.schema().findField("id"), true);
            checkFieldMetrics(files, table.schema().findField("data"), true);
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportPartitions() throws IOException {
        Table table = TABLES.create(SCHEMA, SPEC, tableLocation);
        try {
            writePartitionedSource(sampleRecords());

            SparkTableUtil.importSparkPartitions(
                spark,
                SparkTableUtil.getPartitionsByFilter(spark, SOURCE_TABLE, "data = 'a'"),
                table,
                table.spec(),
                temp.newFolder("staging-dir").toString());

            List<SimpleRecord> expected = Lists.newArrayList(new SimpleRecord(1, "a"));
            Assert.assertEquals("Result rows should match", expected, readIcebergRecords());
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportPartitionsWithSnapshotInheritance() throws IOException {
        Table table = TABLES.create(SCHEMA, SPEC, tableLocation);
        table.updateProperties().set("compatibility.snapshot-id-inheritance.enabled", "true").commit();
        try {
            writePartitionedSource(sampleRecords());

            SparkTableUtil.importSparkPartitions(
                spark,
                SparkTableUtil.getPartitionsByFilter(spark, SOURCE_TABLE, "data = 'a'"),
                table,
                table.spec(),
                temp.newFolder("staging-dir").toString());

            List<SimpleRecord> expected = Lists.newArrayList(new SimpleRecord(1, "a"));
            Assert.assertEquals("Result rows should match", expected, readIcebergRecords());
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportTableWithMappingForNestedData() throws IOException {
        try {
            Dataset<Row> firstRow = spark.range(1L, 2L)
                .withColumn("extra_col", functions.lit(-1))
                .withColumn("struct", functions.expr("named_struct('nested_1', 'a', 'nested_2', 'd', 'nested_3', 'f')"));
            Dataset<Row> secondRow = spark.range(2L, 3L)
                .withColumn("extra_col", functions.lit(-1))
                .withColumn("struct", functions.expr("named_struct('nested_1', 'b', 'nested_2', 'e', 'nested_3', 'g')"));
            firstRow.union(secondRow)
                .coalesce(1)
                .select("id", "extra_col", "struct")
                .write()
                .format("parquet")
                .mode("append")
                .option("path", temp.newFolder(SOURCE_TABLE).toURI().toString())
                .saveAsTable(SOURCE_TABLE);

            // The Iceberg schema deliberately omits extra_col and struct.nested_2.
            Schema schema = new Schema(
                Types.NestedField.optional(1, "id", Types.LongType.get()),
                Types.NestedField.required(2, "struct", Types.StructType.of(
                    Types.NestedField.required(3, "nested_1", Types.StringType.get()),
                    Types.NestedField.required(4, "nested_3", Types.StringType.get()))));
            Table table = TABLES.create(schema, PartitionSpec.unpartitioned(), tableLocation);
            table.updateProperties()
                .set("write.metadata.metrics.default", "counts")
                .set("write.metadata.metrics.column.id", "full")
                .set("write.metadata.metrics.column.struct.nested_3", "full")
                .set("schema.name-mapping.default", NameMappingParser.toJson(MappingUtil.create(schema)))
                .commit();

            SparkTableUtil.importSparkTable(
                spark, new TableIdentifier(SOURCE_TABLE), table, temp.newFolder("staging-dir").toString());

            Assert.assertEquals("Rows must match",
                spark.table(SOURCE_TABLE).select("id", "struct.nested_1", "struct.nested_3").collectAsList(),
                spark.read().format("iceberg").load(tableLocation)
                    .select("id", "struct.nested_1", "struct.nested_3").collectAsList());

            Dataset<Row> files = readFilesTable();
            List<Row> bounds = files.select("lower_bounds", "upper_bounds").collectAsList();
            Assert.assertEquals("Must have lower bounds for 2 columns", 2, bounds.get(0).getMap(0).size());
            Assert.assertEquals("Must have upper bounds for 2 columns", 2, bounds.get(0).getMap(1).size());
            checkFieldMetrics(files, table.schema().findField("struct.nested_1"), true);
            checkFieldMetrics(files, table.schema().findField("id"), 1L, 2L);
            checkFieldMetrics(files, table.schema().findField("struct.nested_3"), "f", "g");
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportTableWithMappingForNestedDataPartitionedTable() throws IOException {
        try {
            Dataset<Row> firstRow = spark.range(1L, 2L)
                .withColumn("extra_col", functions.lit(-1))
                .withColumn("struct", functions.expr("named_struct('nested_1', 'a', 'nested_2', 'd', 'nested_3', 'f')"))
                .withColumn("data", functions.lit("Z"));
            Dataset<Row> secondRow = spark.range(2L, 3L)
                .withColumn("extra_col", functions.lit(-1))
                .withColumn("struct", functions.expr("named_struct('nested_1', 'b', 'nested_2', 'e', 'nested_3', 'g')"))
                .withColumn("data", functions.lit("Z"));
            firstRow.union(secondRow)
                .coalesce(1)
                .select("id", "extra_col", "struct", "data")
                .write()
                .format("parquet")
                .mode("append")
                .option("path", temp.newFolder(SOURCE_TABLE).toURI().toString())
                .partitionBy("data")
                .saveAsTable(SOURCE_TABLE);

            // The Iceberg schema deliberately omits extra_col and struct.nested_2.
            Schema schema = new Schema(
                Types.NestedField.optional(1, "id", Types.LongType.get()),
                Types.NestedField.required(2, "struct", Types.StructType.of(
                    Types.NestedField.required(4, "nested_1", Types.StringType.get()),
                    Types.NestedField.required(5, "nested_3", Types.StringType.get()))),
                Types.NestedField.required(3, "data", Types.StringType.get()));
            Table table = TABLES.create(schema, PartitionSpec.builderFor(schema).identity("data").build(), tableLocation);
            table.updateProperties()
                .set("write.metadata.metrics.default", "counts")
                .set("write.metadata.metrics.column.id", "full")
                .set("write.metadata.metrics.column.struct.nested_3", "full")
                .set("schema.name-mapping.default", NameMappingParser.toJson(MappingUtil.create(schema)))
                .commit();

            SparkTableUtil.importSparkTable(
                spark, new TableIdentifier(SOURCE_TABLE), table, temp.newFolder("staging-dir").toString());

            Assert.assertEquals("Rows must match",
                spark.table(SOURCE_TABLE).select("id", "struct.nested_1", "struct.nested_3", "data").collectAsList(),
                spark.read().format("iceberg").load(tableLocation)
                    .select("id", "struct.nested_1", "struct.nested_3", "data").collectAsList());

            Dataset<Row> files = readFilesTable();
            List<Row> bounds = files.select("lower_bounds", "upper_bounds").collectAsList();
            Assert.assertEquals("Must have lower bounds for 2 columns", 2, bounds.get(0).getMap(0).size());
            Assert.assertEquals("Must have upper bounds for 2 columns", 2, bounds.get(0).getMap(1).size());
            checkFieldMetrics(files, table.schema().findField("struct.nested_1"), true);
            checkFieldMetrics(files, table.schema().findField("id"), 1L, 2L);
            checkFieldMetrics(files, table.schema().findField("struct.nested_3"), "f", "g");
        } finally {
            dropSourceTable();
        }
    }

    @Test
    public void testImportTableWithInt96Timestamp() throws IOException {
        String parquetTableLocation = temp.newFolder(SOURCE_TABLE).toURI().toString();
        String timestampTypeKey = SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE().key();
        try {
            spark.conf().set(timestampTypeKey, "INT96");
            spark.range(1L, 10L)
                .withColumn("tmp_col", functions.to_timestamp(functions.lit("2010-03-20 10:40:30.1234")))
                .coalesce(1)
                .select("id", "tmp_col")
                .write()
                .format("parquet")
                .mode("append")
                .option("path", parquetTableLocation)
                .saveAsTable(SOURCE_TABLE);

            Schema schema = new Schema(
                Types.NestedField.optional(1, "id", Types.LongType.get()),
                Types.NestedField.optional(2, "tmp_col", Types.TimestampType.withZone()));
            Table table = TABLES.create(schema, PartitionSpec.unpartitioned(), tableLocation);
            table.updateProperties()
                .set("write.metadata.metrics.default", "full")
                .set("read.parquet.vectorization.enabled", "false")
                .commit();

            SparkTableUtil.importSparkTable(
                spark, new TableIdentifier(SOURCE_TABLE), table, temp.newFolder("staging-dir").toString());

            Assert.assertEquals("Rows must match",
                spark.table(SOURCE_TABLE).select("id", "tmp_col").collectAsList(),
                spark.read().format("iceberg").load(tableLocation).select("id", "tmp_col").collectAsList());

            // The test expects no bounds for the INT96 column even though metrics mode is "full".
            Dataset<Row> files = readFilesTable();
            checkFieldMetrics(files, table.schema().findField("tmp_col"), true);
            checkFieldMetrics(files, table.schema().findField("id"), 1L, 9L);
        } finally {
            // The SparkSession is shared by every test in this class; don't leak the INT96 override.
            spark.conf().unset(timestampTypeKey);
            dropSourceTable();
        }
    }

    /** Returns the three sample records used by the bean-based import tests. */
    private static List<SimpleRecord> sampleRecords() {
        return Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(2, "b"), new SimpleRecord(3, "c"));
    }

    /** Writes {@code records} as the Parquet source table, partitioned by {@code data}. */
    private void writePartitionedSource(List<SimpleRecord> records) throws IOException {
        spark.createDataFrame(records, SimpleRecord.class)
            .select("id", "data")
            .write()
            .format("parquet")
            .mode("append")
            .option("path", temp.newFolder(SOURCE_TABLE).toURI().toString())
            .partitionBy("data")
            .saveAsTable(SOURCE_TABLE);
    }

    /** Reads the Iceberg table at {@link #tableLocation} as {@link SimpleRecord}s ordered by {@code id}. */
    private List<SimpleRecord> readIcebergRecords() {
        return spark.read().format("iceberg").load(tableLocation)
            .orderBy("id")
            .as(Encoders.bean(SimpleRecord.class))
            .collectAsList();
    }

    /** Loads the {@code #files} metadata table of the Iceberg table at {@link #tableLocation}. */
    private Dataset<Row> readFilesTable() {
        return spark.read().format("iceberg").load(tableLocation + "#files");
    }

    /**
     * Drops the Spark source table. {@code IF EXISTS} keeps cleanup from throwing (and masking
     * the real test failure) when the table was never created.
     */
    private static void dropSourceTable() {
        spark.sql("DROP TABLE IF EXISTS " + SOURCE_TABLE);
    }

    /** Builds the {@code lower_bounds[id]} / {@code upper_bounds[id]} select expressions for a field. */
    private static String[] boundsExpressions(Types.NestedField field) {
        int id = field.fieldId();
        return new String[] {String.format("lower_bounds['%d']", id), String.format("upper_bounds['%d']", id)};
    }

    /**
     * Asserts that the first data file's lower/upper bounds for {@code field} decode to
     * {@code min}/{@code max}. Values are compared through {@code toString()}.
     *
     * @param files the Iceberg {@code #files} metadata table
     * @param field the column whose bounds are checked
     * @param min expected lower bound
     * @param max expected upper bound
     */
    private void checkFieldMetrics(Dataset<Row> files, Types.NestedField field, Object min, Object max) {
        List<Row> rows = files.selectExpr(boundsExpressions(field)).collectAsList();
        Assert.assertFalse("Expected at least one data file", rows.isEmpty());

        Row first = rows.get(0);
        byte[] lower = first.getAs(0);
        byte[] upper = first.getAs(1);
        // Fail with a clear message rather than an NPE from ByteBuffer.wrap(null).
        Assert.assertNotNull("Missing lower bound for column: " + field.name(), lower);
        Assert.assertNotNull("Missing upper bound for column: " + field.name(), upper);

        Assert.assertEquals("Min value should match",
            min.toString(), Conversions.fromByteBuffer(field.type(), ByteBuffer.wrap(lower)).toString());
        Assert.assertEquals("Max value should match",
            max.toString(), Conversions.fromByteBuffer(field.type(), ByteBuffer.wrap(upper)).toString());
    }

    /**
     * Asserts, for every data file, whether the lower and upper bounds of {@code field} are absent.
     *
     * @param files the Iceberg {@code #files} metadata table
     * @param field the column whose bounds are checked
     * @param isNull {@code true} if both bounds must be null, {@code false} if both must be present
     */
    private void checkFieldMetrics(Dataset<Row> files, Types.NestedField field, boolean isNull) {
        List<Row> rows = files.selectExpr(boundsExpressions(field)).collectAsList();
        // Without this, an empty file list would make the loop below pass vacuously.
        Assert.assertFalse("Expected at least one data file", rows.isEmpty());

        for (Row row : rows) {
            Assert.assertEquals("Invalid metrics for column: " + field.name(), isNull, row.isNullAt(0));
            Assert.assertEquals("Invalid metrics for column: " + field.name(), isNull, row.isNullAt(1));
        }
    }
}
