package io.basestar.spark;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.basestar.spark.BucketFunction;
import io.basestar.util.Nullsafe;
import java.io.IOException;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.net.URI;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.catalog.CatalogTable;
import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition;
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog;
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;

/* loaded from: input_file:io/basestar/spark/PartitionedUpsertSink.class */
public class PartitionedUpsertSink extends PartitionedUpsert implements Sink<Dataset<Row>> {
    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(PartitionedUpsertSink.class);
    /** Value used for {@link #mergeCount} when the builder does not supply one. */
    private static final Integer DEFAULT_MERGE_COUNT = 100;
    /** Name of the working column that carries a row's state during an upsert; it is dropped before the write. */
    private static final String STATE_COLUMN = "__state";
    /** State given to an incoming row that has no matching stored row. */
    private static final String CREATE_STATE = "CREATE";
    /** State given to an incoming row that is not flagged by the deleted column. */
    private static final String UPDATE_STATE = "UPDATE";
    /** State given to an incoming row whose deleted column evaluates to true. */
    private static final String DELETE_STATE = "DELETE";
    /** State given to rows loaded from already-registered partitions. */
    private static final String IGNORE_STATE = "IGNORE";
    /** Catalog database that holds the target table. */
    private final String databaseName;
    /** Catalog name of the target table; also used in log messages and Spark job descriptions. */
    private final String tableName;
    /** Columns on which incoming rows are matched against stored rows (defaults to {@code ["id"]}). */
    private final List<String> idColumns;
    /** Upsert-partition value assigned to every partition this sink creates or rewrites. */
    private final String upsertId;
    /** Storage format used for reading existing data and writing the result (defaults to {@code Format.PARQUET}). */
    private final Format format;
    /** Name of the incoming column that flags a row as deleted (defaults to {@code "__deleted"}). */
    private final String deletedColumn;
    /** File-count threshold: an untouched existing partition holding more files than this is rewritten instead of appended to. */
    private final int mergeCount;

    /* loaded from: input_file:io/basestar/spark/PartitionedUpsertSink$Builder.class */
    /**
     * Fluent builder for {@link PartitionedUpsertSink}.
     *
     * <p>Every setter stores its argument unchanged and returns this builder.
     * Values left unset are passed to the sink constructor as {@code null}.
     */
    public static class Builder {
        private String databaseName;
        private String tableName;
        private List<String> idColumns;
        private String upsertId;
        private Format format;
        private String deletedColumn;
        private Integer mergeCount;

        /** Package-private; instances are obtained through {@code PartitionedUpsertSink.builder()}. */
        Builder() {
        }

        /** Sets the catalog database name. */
        public Builder databaseName(String str) {
            databaseName = str;
            return this;
        }

        /** Sets the catalog table name. */
        public Builder tableName(String str) {
            tableName = str;
            return this;
        }

        /** Sets the columns used to match incoming rows with stored rows. */
        public Builder idColumns(List<String> list) {
            idColumns = list;
            return this;
        }

        /** Sets the upsert id assigned to partitions written by the sink. */
        public Builder upsertId(String str) {
            upsertId = str;
            return this;
        }

        /** Sets the storage format. */
        public Builder format(Format format) {
            this.format = format;
            return this;
        }

        /** Sets the name of the column that flags deleted rows. */
        public Builder deletedColumn(String str) {
            deletedColumn = str;
            return this;
        }

        /** Sets the file-count threshold above which an existing partition is rewritten. */
        public Builder mergeCount(Integer num) {
            mergeCount = num;
            return this;
        }

        /** Creates the sink from the values collected so far. */
        public PartitionedUpsertSink build() {
            return new PartitionedUpsertSink(databaseName, tableName, idColumns, upsertId, format, deletedColumn, mergeCount);
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("PartitionedUpsertSink.Builder(");
            text.append("databaseName=").append(databaseName);
            text.append(", tableName=").append(tableName);
            text.append(", idColumns=").append(idColumns);
            text.append(", upsertId=").append(upsertId);
            text.append(", format=").append(format);
            text.append(", deletedColumn=").append(deletedColumn);
            text.append(", mergeCount=").append(mergeCount);
            return text.append(")").toString();
        }
    }

    /* loaded from: input_file:io/basestar/spark/PartitionedUpsertSink$Partition.class */
    /**
     * Ordered partition-column values identifying one partition of the target table.
     *
     * <p>The class is a mutable bean (no-arg constructor plus getter/setter) so it can be
     * used with {@code Encoders.bean}. Equality and hash code are based on the array
     * contents, which lets instances serve as map keys. The array is stored and returned
     * without copying, so callers must not modify it while the instance is used as a key.
     */
    public static class Partition implements Serializable {
        private String[] values;

        /** Creates an instance with no values; {@link #setValues} is expected to follow. */
        public Partition() {
        }

        /**
         * @param strArr partition values, in the same order as the table's partition columns
         */
        public Partition(String[] strArr) {
            this.values = strArr;
        }

        /**
         * Pairs each partition column name with the value at the same position.
         *
         * @param list partition column names, in table order
         * @return a new mutable map from column name to partition value
         */
        public Map<String, String> spec(List<String> list) {
            Map<String, String> spec = new HashMap<>();
            for (int i = 0; i < list.size(); i++) {
                spec.put(list.get(i), this.values[i]);
            }
            return spec;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj instanceof Partition) {
                return Arrays.deepEquals(this.values, ((Partition) obj).values);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Arrays.deepHashCode(this.values);
        }

        /**
         * Returns the values joined with commas, or {@code "null"} when no values have been
         * set yet (the state right after the no-arg constructor), instead of throwing.
         */
        @Override
        public String toString() {
            return this.values == null ? "null" : String.join(",", this.values);
        }

        /**
         * Builds the relative directory path for this partition, e.g. {@code "year=2020/month=01/"}.
         *
         * @param list partition column names, in table order
         * @return one {@code name=value/} segment per column, concatenated
         */
        public String buildPath(List<String> list) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < list.size(); i++) {
                sb.append(list.get(i)).append("=").append(this.values[i]).append("/");
            }
            return sb.toString();
        }

        public String[] getValues() {
            return this.values;
        }

        public void setValues(String[] strArr) {
            this.values = strArr;
        }
    }

    /* loaded from: input_file:io/basestar/spark/PartitionedUpsertSink$PartitionState.class */
    /**
     * Bean pairing a {@link Partition} with one row state observed in it.
     *
     * <p>Mutable with a no-arg constructor so it can be used with {@code Encoders.bean}.
     * Two instances are equal when both partition and state are equal (nulls match nulls).
     */
    public static class PartitionState implements Serializable {
        private Partition partition;
        private String state;

        public PartitionState() {
        }

        public PartitionState(Partition partition, String str) {
            this.partition = partition;
            this.state = str;
        }

        public Partition getPartition() {
            return this.partition;
        }

        public void setPartition(Partition partition) {
            this.partition = partition;
        }

        public String getState() {
            return this.state;
        }

        public void setState(String str) {
            this.state = str;
        }

        /** Hook that lets a subclass refuse equality with instances of this class. */
        protected boolean canEqual(Object obj) {
            return obj instanceof PartitionState;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof PartitionState)) {
                return false;
            }
            PartitionState that = (PartitionState) obj;
            if (!that.canEqual(this)) {
                return false;
            }
            Partition ownPartition = getPartition();
            Partition otherPartition = that.getPartition();
            boolean samePartition = ownPartition == null ? otherPartition == null : ownPartition.equals(otherPartition);
            if (!samePartition) {
                return false;
            }
            String ownState = getState();
            String otherState = that.getState();
            if (ownState == null) {
                return otherState == null;
            }
            return ownState.equals(otherState);
        }

        @Override
        public int hashCode() {
            // Multiplier 59 and null placeholder 43 are kept so hash values stay unchanged.
            int result = 1;
            Partition ownPartition = getPartition();
            result = result * 59 + (ownPartition != null ? ownPartition.hashCode() : 43);
            String ownState = getState();
            result = result * 59 + (ownState != null ? ownState.hashCode() : 43);
            return result;
        }

        @Override
        public String toString() {
            return new StringBuilder().append(this.partition).append("=").append(this.state).toString();
        }
    }

    PartitionedUpsertSink(String str, String str2, List<String> list, String str3, Format format, String str4, Integer num) {
        this.databaseName = str;
        this.tableName = str2;
        this.idColumns = (List) Nullsafe.option(list, ImmutableList.of("id"));
        this.upsertId = (String) Nullsafe.option(str3, PartitionedUpsert::defaultUpsertId);
        this.format = (Format) Nullsafe.option(format, Format.PARQUET);
        this.deletedColumn = (String) Nullsafe.option(str4, "__deleted");
        this.mergeCount = ((Integer) Nullsafe.option(num, DEFAULT_MERGE_COUNT)).intValue();
    }

    /**
     * Normalises the incoming dataset before it is merged with stored data.
     *
     * <ul>
     *   <li>Removes the upsert partition column if the input carries one. The sink assigns
     *       that column itself later on, so a pre-existing one would end up duplicated.
     *       (The previous condition was inverted and only "dropped" the column when it was
     *       absent, which did nothing.)</li>
     *   <li>Adds the state column: {@code DELETE} where the deleted column is true,
     *       {@code UPDATE} otherwise, then removes the deleted column. Without a deleted
     *       column every row is marked {@code UPDATE}.</li>
     * </ul>
     *
     * @param dataset incoming rows
     * @return the rows with the state column added
     */
    private Dataset<Row> clean(Dataset<Row> dataset) {
        List<String> fieldNames = Arrays.asList(dataset.schema().fieldNames());
        Dataset<Row> stripped = dataset;
        if (fieldNames.contains(PartitionedUpsert.UPSERT_PARTITION)) {
            stripped = stripped.drop(PartitionedUpsert.UPSERT_PARTITION);
        }
        if (!fieldNames.contains(this.deletedColumn)) {
            return stripped.withColumn(STATE_COLUMN, functions.lit(UPDATE_STATE));
        }
        Column state = functions.when(stripped.col(this.deletedColumn), DELETE_STATE).otherwise(UPDATE_STATE);
        return stripped.withColumn(STATE_COLUMN, state).drop(this.deletedColumn);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Reads the partition values of a row.
     *
     * @param row row to read from
     * @param iArr field indexes of the partition columns, in partition order
     * @return a {@link Partition} holding the string value found at each index
     */
    public static Partition partition(Row row, int[] iArr) {
        final String[] extracted = new String[iArr.length];
        int slot = 0;
        for (int fieldIndex : iArr) {
            extracted[slot] = row.getString(fieldIndex);
            slot++;
        }
        return new Partition(extracted);
    }

    /**
     * Finds which partitions touched by the dataset are already registered in the catalog.
     *
     * <p>Collects the distinct partition values to the driver, then asks the catalog about
     * each of them in turn.
     *
     * @param externalCatalog catalog to query
     * @param str database name
     * @param str2 table name
     * @param list partition column names, in table order
     * @param dataset rows whose partitions are of interest
     * @return the registered partitions keyed by partition values; unregistered ones are omitted
     */
    private static Map<Partition, CatalogTablePartition> existing(ExternalCatalog externalCatalog, String str, String str2, List<String> list, Dataset<Row> dataset) {
        StructType inputSchema = dataset.schema();
        int[] partitionIndexes = list.stream().mapToInt(inputSchema::fieldIndex).toArray();
        List<Partition> touched = dataset.map(row -> partition(row, partitionIndexes), Encoders.bean(Partition.class)).distinct().collectAsList();
        Map<Partition, CatalogTablePartition> registered = new HashMap<>();
        for (Partition candidate : touched) {
            Option<CatalogTablePartition> lookup = externalCatalog.getPartitionOption(str, str2, ScalaUtils.asScalaMap(candidate.spec(list)));
            if (lookup.isDefined()) {
                registered.put(candidate, lookup.get());
            }
        }
        return registered;
    }

    /**
     * Summarises which row states occur in each partition of the dataset.
     *
     * @param list partition column names, in table order
     * @param dataset rows carrying the state column
     * @return for every partition present in the dataset, the set of distinct states found in it
     */
    private static Map<Partition, Set<String>> states(List<String> list, Dataset<Row> dataset) {
        StructType inputSchema = dataset.schema();
        int stateIndex = inputSchema.fieldIndex(STATE_COLUMN);
        int[] partitionIndexes = list.stream().mapToInt(inputSchema::fieldIndex).toArray();
        // Only the distinct (partition, state) pairs are brought back to the driver.
        List<PartitionState> observed = dataset
                .map(row -> new PartitionState(partition(row, partitionIndexes), row.getString(stateIndex)), Encoders.bean(PartitionState.class))
                .distinct()
                .collectAsList();
        return observed.stream().collect(
                Collectors.groupingBy(PartitionState::getPartition, Collectors.mapping(PartitionState::getState, Collectors.toSet())));
    }

    /**
     * Loads the rows currently stored in the given partitions.
     *
     * @param sQLContext context used for reading
     * @param uri table root, passed to the reader as {@code basePath}
     * @param format storage format of the table
     * @param structType schema to read with and to conform the rows to
     * @param collection catalog partitions to read
     * @return an empty dataset when no partitions are given; otherwise the stored rows, each
     *         marked with the {@code IGNORE} state and conformed to {@code structType}
     */
    private static Dataset<Row> load(SQLContext sQLContext, URI uri, Format format, StructType structType, Collection<CatalogTablePartition> collection) {
        if (collection.isEmpty()) {
            return sQLContext.createDataFrame(ImmutableList.of(), structType);
        }
        String[] locations = collection.stream()
                .map(stored -> stored.location().toString())
                .toArray(String[]::new);
        Dataset<Row> raw = sQLContext.read()
                .format(format.getSparkFormat())
                .option("basePath", uri.toString())
                .schema(structType)
                .load(locations);
        Dataset<Row> marked = raw.withColumn(STATE_COLUMN, functions.lit(IGNORE_STATE));
        return marked.map(row -> SparkSchemaUtils.conform(row, structType), RowEncoder.apply(structType));
    }

    /**
     * Full-outer-joins incoming rows with stored rows on the id columns and picks one row per match.
     *
     * <ul>
     *   <li>incoming and stored both present: the incoming row is kept as is;</li>
     *   <li>only incoming present: the incoming row with its state set to {@code CREATE};</li>
     *   <li>only stored present: the stored row conformed to the incoming schema.</li>
     * </ul>
     *
     * <p>NOTE(review): an incoming row in the delete state that has no stored counterpart is
     * also relabelled {@code CREATE} here - confirm that is intended.
     *
     * @param list id column names; must not be empty
     * @param dataset stored rows
     * @param dataset2 incoming rows; their schema is the schema of the result
     * @return the merged rows
     * @throws IllegalStateException if {@code list} is empty
     */
    private static Dataset<Row> join(List<String> list, Dataset<Row> dataset, Dataset<Row> dataset2) {
        Column sameIds = list.stream()
                .map(name -> dataset2.col(name).equalTo(dataset.col(name)))
                .reduce((left, right) -> left.and(right))
                .orElseThrow(IllegalStateException::new);
        StructType incomingSchema = dataset2.schema();
        return dataset2.joinWith(dataset, sameIds, "full_outer").map(pair -> {
            if (pair._1() == null) {
                return SparkSchemaUtils.conform((Row) pair._2(), incomingSchema);
            }
            if (pair._2() == null) {
                return SparkSchemaUtils.with((Row) pair._1(), ImmutableMap.of(STATE_COLUMN, CREATE_STATE));
            }
            return (Row) pair._1();
        }, RowEncoder.apply(dataset2.schema()));
    }

    /**
     * Appends the upsert partition column to every row.
     *
     * @param list partition column names, in table order
     * @param dataset merged rows
     * @param map upsert id to assign, keyed by partition
     * @return the rows with one extra trailing string column holding the upsert id that
     *         {@code map} gives for the row's partition ({@code null} if the map has no entry)
     */
    private Dataset<Row> output(List<String> list, Dataset<Row> dataset, Map<Partition, String> map) {
        StructType inputSchema = dataset.schema();
        StructField[] inputFields = inputSchema.fields();
        final int width = inputFields.length;
        StructField[] outputFields = Arrays.copyOf(inputFields, width + 1);
        outputFields[width] = SparkSchemaUtils.field(PartitionedUpsert.UPSERT_PARTITION, DataTypes.StringType);
        StructType outputSchema = DataTypes.createStructType(outputFields);
        int[] partitionIndexes = list.stream().mapToInt(inputSchema::fieldIndex).toArray();
        // The lambda uses only locals and the static partition(...) helper, so it does not capture this sink.
        return dataset.map(row -> {
            Partition key = partition(row, partitionIndexes);
            Object[] cells = new Object[width + 1];
            for (int column = 0; column < width; column++) {
                cells[column] = row.get(column);
            }
            cells[width] = map.get(key);
            return new GenericRow(cells);
        }, RowEncoder.apply(outputSchema));
    }

    /**
     * Concatenates two path fragments, inserting a {@code /} unless the first already ends with one.
     *
     * @param str leading fragment
     * @param str2 trailing fragment, appended unchanged
     * @return the combined path
     */
    private static String joinPaths(String str, String str2) {
        if (str.endsWith("/")) {
            return str + str2;
        }
        return str + "/" + str2;
    }

    /**
     * Upserts the given rows into the catalog table.
     *
     * <ol>
     *   <li>Cleans and caches the input, and finds which of the touched partitions are
     *       already registered in the catalog.</li>
     *   <li>Loads the stored rows of those partitions and joins them with the input.</li>
     *   <li>Chooses an upsert id per partition: this sink's own id for new partitions, for
     *       partitions receiving updates or deletes, and for partitions holding more than
     *       {@code mergeCount} files; otherwise the id taken from the partition's current
     *       location.</li>
     *   <li>Writes in append mode: partitions using this sink's id receive every row that is
     *       not deleted, all other partitions receive only newly created rows.</li>
     *   <li>Creates or alters the catalog entry of every partition written under this
     *       sink's id so that it points at the new location.</li>
     * </ol>
     *
     * <p>The cached datasets are released and the Spark job description is cleared even
     * when a step fails.
     *
     * @param dataset incoming rows
     */
    @Override // io.basestar.spark.Sink
    public void accept(Dataset<Row> dataset) {
        SparkSession sparkSession = dataset.sparkSession();
        SparkContext sparkContext = sparkSession.sparkContext();
        SQLContext sqlContext = dataset.sqlContext();
        sparkContext.setJobDescription("Upsert to " + this.tableName);
        // Assigned inside the try block so that the finally block can release whatever was cached.
        Dataset<Row> incoming = null;
        Dataset<Row> merged = null;
        try {
            ExternalCatalogWithListener externalCatalog = sparkSession.sharedState().externalCatalog();
            CatalogTable table = externalCatalog.getTable(this.databaseName, this.tableName);
            List<String> partitionColumns = ScalaUtils.asJavaList(table.partitionColumnNames());
            URI location = table.location();

            incoming = clean(dataset).cache();
            Map<Partition, CatalogTablePartition> existing = existing(externalCatalog, this.databaseName, this.tableName, partitionColumns, incoming);
            log.info("Upsert on {} refers to existing partitions: {}", this.tableName, existing.keySet());

            merged = join(this.idColumns, load(sqlContext, location, this.format, incoming.schema(), existing.values()), incoming).cache();
            Map<Partition, Set<String>> states = states(partitionColumns, merged);

            AtomicInteger created = new AtomicInteger(0);
            AtomicInteger rewritten = new AtomicInteger(0);
            AtomicInteger appended = new AtomicInteger(0);
            Configuration hadoopConfiguration = sparkContext.hadoopConfiguration();
            Map<Partition, String> upsertIds = states.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> {
                Set<String> partitionStates = entry.getValue();
                CatalogTablePartition current = existing.get(entry.getKey());
                if (current == null) {
                    // Not registered yet: written as a new partition.
                    created.incrementAndGet();
                    return this.upsertId;
                }
                if (partitionStates.contains(UPDATE_STATE) || partitionStates.contains(DELETE_STATE)) {
                    // Stored rows change, so the whole partition is rewritten.
                    rewritten.incrementAndGet();
                    return this.upsertId;
                }
                try {
                    Path path = new Path(current.location());
                    int fileCount = path.getFileSystem(hadoopConfiguration).listStatus(path).length;
                    if (fileCount > this.mergeCount) {
                        log.info("Merging {} because it has {} files", entry.getKey(), Integer.valueOf(fileCount));
                        rewritten.incrementAndGet();
                        return this.upsertId;
                    }
                } catch (IOException e) {
                    // Best effort: when the files cannot be listed, fall through and append.
                    log.error("Failed to check file count for {}", current.location(), e);
                }
                appended.incrementAndGet();
                return extractUpsertId(current.location());
            }));
            log.info("Upsert on {} has target states: {}", this.tableName, states);
            sparkContext.setJobDescription("Upsert to " + this.tableName + " (create: " + created + ", merge: " + rewritten + ", append: " + appended + ")");

            Dataset<Row> tagged = output(partitionColumns, merged, upsertIds);
            // Partitions written under this sink's id get every surviving row; others get new rows only.
            Column keep = functions.when(tagged.col(PartitionedUpsert.UPSERT_PARTITION).equalTo(this.upsertId), tagged.col(STATE_COLUMN).notEqual(DELETE_STATE))
                    .otherwise(tagged.col(STATE_COLUMN).equalTo(CREATE_STATE));
            Dataset<Row> toWrite = tagged.filter(keep).drop(STATE_COLUMN);
            String[] writePartitionColumns = Stream.concat(partitionColumns.stream(), Stream.of(PartitionedUpsert.UPSERT_PARTITION)).toArray(String[]::new);
            Column[] repartitionColumns = Arrays.stream(writePartitionColumns).map(toWrite::col).toArray(Column[]::new);
            toWrite.repartition(1, repartitionColumns)
                    .write()
                    .format(this.format.getSparkFormat())
                    .mode(SaveMode.Append)
                    .partitionBy(writePartitionColumns)
                    .save(location.toString());

            upsertIds.forEach((key, targetId) -> {
                if (!this.upsertId.equals(targetId)) {
                    // Appended in place; the catalog entry already points at the right location.
                    return;
                }
                Map<String, String> spec = key.spec(partitionColumns);
                URI partitionLocation = URI.create(joinPaths(location.toString(), key.buildPath(partitionColumns) + PartitionedUpsert.UPSERT_PARTITION + "=" + targetId + "/"));
                CatalogTablePartition catalogPartition = SparkUtils.partition(spec, this.format, partitionLocation);
                if (existing.get(key) == null) {
                    log.info("Creating partition {} with location {}", key, partitionLocation);
                    externalCatalog.createPartitions(this.databaseName, this.tableName, Option.apply(catalogPartition).toList(), false);
                } else {
                    log.info("Updating partition {} with location {}", key, partitionLocation);
                    externalCatalog.alterPartitions(this.databaseName, this.tableName, Option.apply(catalogPartition).toList());
                }
            });
        } finally {
            if (merged != null) {
                merged.unpersist(true);
            }
            if (incoming != null) {
                incoming.unpersist(true);
            }
            sparkContext.setJobDescription((String) null);
        }
    }

    /**
     * Starts building a sink.
     *
     * @return a new, empty {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /*
     * Rebuilds one of this class's serializable lambdas from its serialized form.
     *
     * The implementation method name selects a branch. Each branch then checks the functional
     * interface (MapFunction.call), the implementing class and the exact implementation
     * signature before recreating the lambda from the captured arguments. Any name or
     * signature that is not recognised ends in IllegalArgumentException.
     *
     * NOTE(review): this appears to be decompiler output for the compiler-synthesized
     * deserialization hook (see the "synthetic" marker). The selector is declared as
     * `boolean z = -1` and the second switch mixes `case false`, repeated `case true` and a
     * numeric constant, which is not valid Java source. Hand-written source would normally
     * not declare this method at all - confirm before trying to compile or edit it.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        // First switch: map the implementation method name to a selector value.
        switch (implMethodName.hashCode()) {
            case -1828056515:
                if (implMethodName.equals("lambda$existing$509e993$1")) {
                    z = true;
                    break;
                }
                break;
            case -51952602:
                if (implMethodName.equals("lambda$output$db23585f$1")) {
                    z = false;
                    break;
                }
                break;
            case 100189973:
                if (implMethodName.equals("lambda$join$24fb93d7$1")) {
                    z = 3;
                    break;
                }
                break;
            case 1395808221:
                if (implMethodName.equals("lambda$load$cbfb273b$1")) {
                    z = 4;
                    break;
                }
                break;
            case 1541008613:
                if (implMethodName.equals("lambda$states$3eb6412b$1")) {
                    z = 2;
                    break;
                }
                break;
        }
        // Second switch: verify the lambda's shape and recreate it from its captured arguments.
        switch (z) {
            // Row mapper that appends the upsert id looked up for the row's partition.
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("io/basestar/spark/PartitionedUpsertSink") && serializedLambda.getImplMethodSignature().equals("([I[Lorg/apache/spark/sql/types/StructField;[Lorg/apache/spark/sql/types/StructField;Ljava/util/Map;Lorg/apache/spark/sql/Row;)Lorg/apache/spark/sql/Row;")) {
                    int[] iArr = (int[]) serializedLambda.getCapturedArg(0);
                    StructField[] structFieldArr = (StructField[]) serializedLambda.getCapturedArg(1);
                    StructField[] structFieldArr2 = (StructField[]) serializedLambda.getCapturedArg(2);
                    Map map = (Map) serializedLambda.getCapturedArg(3);
                    return row -> {
                        Partition partition = partition(row, iArr);
                        Object[] objArr = new Object[structFieldArr.length];
                        for (int i2 = 0; i2 != structFieldArr2.length; i2++) {
                            objArr[i2] = row.get(i2);
                        }
                        objArr[structFieldArr2.length] = map.get(partition);
                        return new GenericRow(objArr);
                    };
                }
                break;
            // Row -> Partition mapper.
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("io/basestar/spark/PartitionedUpsertSink") && serializedLambda.getImplMethodSignature().equals("([ILorg/apache/spark/sql/Row;)Lio/basestar/spark/PartitionedUpsertSink$Partition;")) {
                    int[] iArr2 = (int[]) serializedLambda.getCapturedArg(0);
                    return row2 -> {
                        return partition(row2, iArr2);
                    };
                }
                break;
            // Row -> PartitionState mapper. NOTE(review): the label is presumably the literal 2
            // that the decompiler replaced with an unrelated constant of the same value.
            case BucketFunction.HashPrefix.DEFAULT_LEN /* 2 */:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("io/basestar/spark/PartitionedUpsertSink") && serializedLambda.getImplMethodSignature().equals("([IILorg/apache/spark/sql/Row;)Lio/basestar/spark/PartitionedUpsertSink$PartitionState;")) {
                    int[] iArr3 = (int[]) serializedLambda.getCapturedArg(0);
                    int intValue = ((Integer) serializedLambda.getCapturedArg(1)).intValue();
                    return row3 -> {
                        return new PartitionState(partition(row3, iArr3), row3.getString(intValue));
                    };
                }
                break;
            // Tuple2 mapper that picks the surviving row of a joined pair.
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("io/basestar/spark/PartitionedUpsertSink") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/spark/sql/types/StructType;Lscala/Tuple2;)Lorg/apache/spark/sql/Row;")) {
                    StructType structType = (StructType) serializedLambda.getCapturedArg(0);
                    return tuple2 -> {
                        return tuple2._1() != null ? tuple2._2() != null ? (Row) tuple2._1() : SparkSchemaUtils.with((Row) tuple2._1(), ImmutableMap.of(STATE_COLUMN, CREATE_STATE)) : SparkSchemaUtils.conform((Row) tuple2._2(), structType);
                    };
                }
                break;
            // Row mapper that conforms a row to the captured schema.
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/api/java/function/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("io/basestar/spark/PartitionedUpsertSink") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/spark/sql/types/StructType;Lorg/apache/spark/sql/Row;)Lorg/apache/spark/sql/Row;")) {
                    StructType structType2 = (StructType) serializedLambda.getCapturedArg(0);
                    return row4 -> {
                        return SparkSchemaUtils.conform(row4, structType2);
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
