package uk.gov.gchq.gaffer.spark.operation.handler.graphframe;

import java.util.Set;
import java.util.stream.Collectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.functions;
import org.graphframes.GraphFrame;
import uk.gov.gchq.gaffer.data.element.ReservedPropertyNames;
import uk.gov.gchq.gaffer.operation.OperationException;
import uk.gov.gchq.gaffer.operation.io.Output;
import uk.gov.gchq.gaffer.spark.SparkContextUtil;
import uk.gov.gchq.gaffer.spark.operation.dataframe.GetDataFrameOfElements;
import uk.gov.gchq.gaffer.spark.operation.dataframe.converter.schema.SchemaToStructTypeConverter;
import uk.gov.gchq.gaffer.spark.operation.graphframe.GetGraphFrameOfElements;
import uk.gov.gchq.gaffer.spark.utils.scala.DataFrameUtil;
import uk.gov.gchq.gaffer.store.Context;
import uk.gov.gchq.gaffer.store.Store;
import uk.gov.gchq.gaffer.store.operation.handler.OutputOperationHandler;

/* loaded from: input_file:uk/gov/gchq/gaffer/spark/operation/handler/graphframe/GetGraphFrameOfElementsHandler.class */
/**
 * Handler for {@link GetGraphFrameOfElements}.
 *
 * <p>Executes a {@link GetDataFrameOfElements} with the operation's converters, view and options.
 * It then splits the resulting rows into an edge frame and a vertex frame, and wraps them in a
 * {@link GraphFrame}.
 */
public class GetGraphFrameOfElementsHandler implements OutputOperationHandler<GetGraphFrameOfElements, GraphFrame> {

    /** Name of the Spark temporary view the fetched elements are registered under. */
    private static final String ELEMENTS_VIEW_NAME = "elements";

    /** Name of the group column once {@link #renameColumns(Dataset)} has been applied. */
    private static final String GROUP_COL_NAME = "group";

    /**
     * Builds a {@link GraphFrame} from the elements matched by the operation's view.
     *
     * <p>Edges are the rows whose group is one of the view's edge groups. Each edge gets an
     * {@code id} column computed with {@code row_number()} over a window partitioned by group.
     * Vertices are the rows in the view's entity groups. When there is at least one edge, they
     * are combined with the distinct source/destination vertices of those edges. When there are
     * no edges, an empty edge frame from {@link DataFrameUtil#emptyEdges(SparkSession)} is used
     * instead.
     *
     * @param operation the operation supplying converters, view and options
     * @param context   the operation context
     * @param store     the store used to execute the underlying {@link GetDataFrameOfElements}
     * @return a GraphFrame of the vertices and edges described above
     * @throws OperationException if executing the underlying {@link GetDataFrameOfElements} fails
     */
    @Override
    public GraphFrame doOperation(final GetGraphFrameOfElements operation, final Context context, final Store store) throws OperationException {
        final GetDataFrameOfElements getDataFrame = new GetDataFrameOfElements.Builder()
                .converters(operation.getConverters())
                .view(operation.getView())
                .options(operation.getOptions())
                .build();
        final Dataset<Row> elements = renameColumns(store.execute(getDataFrame, context));

        // The view is still registered so that anything relying on the "elements" view keeps
        // working. The frames below are derived from the local Dataset instead of being re-read
        // through the session-wide view. That way a concurrent operation replacing the view
        // cannot change what this one sees.
        elements.createOrReplaceTempView(ELEMENTS_VIEW_NAME);

        final Set<String> edgeGroups = operation.getView().getEdgeGroups();
        final Set<String> entityGroups = operation.getView().getEntityGroups();
        final SparkSession sparkSession = SparkContextUtil.getSparkSession(context, store.getProperties());

        final Dataset<Row> edgeRows = rowsInGroups(elements, edgeGroups);
        // row_number() restarts at 1 in every group, so ids are only unique within a group.
        // Partitioning by group avoids funnelling all edges into a single window partition.
        // NOTE(review): confirm non-globally-unique edge ids are acceptable to GraphFrame consumers.
        Dataset<Row> edges = edgeRows.withColumn(SchemaToStructTypeConverter.ID,
                functions.row_number().over(Window.partitionBy(GROUP_COL_NAME).orderBy(GROUP_COL_NAME)));
        Dataset<Row> vertices = rowsInGroups(elements, entityGroups);

        // Test emptiness on the pre-window rows: same row count, but no window sort is needed.
        if (edgeRows.rdd().isEmpty()) {
            edges = DataFrameUtil.emptyEdges(sparkSession);
        } else {
            final Dataset<Row> endpoints = edgeRows
                    .select(functions.col(SchemaToStructTypeConverter.SRC_COL_NAME)
                            .alias(SchemaToStructTypeConverter.VERTEX_COL_NAME))
                    .union(edgeRows.select(functions.col(SchemaToStructTypeConverter.DST_COL_NAME)
                            .alias(SchemaToStructTypeConverter.VERTEX_COL_NAME)))
                    .distinct();
            vertices = DataFrameUtil.union(endpoints, vertices);
        }
        return GraphFrame.apply(
                vertices.withColumnRenamed(SchemaToStructTypeConverter.VERTEX_COL_NAME, SchemaToStructTypeConverter.ID),
                edges);
    }

    /**
     * Returns the rows of {@code elements} whose group column is one of {@code groups}.
     *
     * <p>{@code Column.isin} binds the group names as literal values. A name containing a quote
     * therefore cannot break, or inject into, the filter the way it could with a concatenated
     * SQL string. An empty set matches no rows.
     */
    private static Dataset<Row> rowsInGroups(final Dataset<Row> elements, final Set<String> groups) {
        return elements.filter(functions.col(GROUP_COL_NAME).isin(groups.toArray()));
    }

    /**
     * Renames Gaffer's reserved element columns (group, id, source, destination, directed,
     * vertex, matched vertex) to the column names used for GraphFrame construction.
     */
    private static Dataset<Row> renameColumns(final Dataset<Row> dataset) {
        return dataset
                .withColumnRenamed(ReservedPropertyNames.GROUP.name(), GROUP_COL_NAME)
                .withColumnRenamed(ReservedPropertyNames.ID.name(), SchemaToStructTypeConverter.ID)
                .withColumnRenamed(ReservedPropertyNames.SOURCE.name(), SchemaToStructTypeConverter.SRC_COL_NAME)
                .withColumnRenamed(ReservedPropertyNames.DESTINATION.name(), SchemaToStructTypeConverter.DST_COL_NAME)
                .withColumnRenamed(ReservedPropertyNames.DIRECTED.name(), SchemaToStructTypeConverter.DIRECTED_COL_NAME)
                .withColumnRenamed(ReservedPropertyNames.VERTEX.name(), SchemaToStructTypeConverter.VERTEX_COL_NAME)
                .withColumnRenamed(ReservedPropertyNames.MATCHED_VERTEX.name(), SchemaToStructTypeConverter.MATCHED_VERTEX_COL_NAME);
    }
}
