package uk.gov.gchq.gaffer.spark.operation.dataframe;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.spark.sql.sources.And;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.LessThan;
import org.apache.spark.sql.sources.Or;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import uk.gov.gchq.gaffer.commonutil.StreamUtil;
import uk.gov.gchq.gaffer.data.elementdefinition.view.View;
import uk.gov.gchq.gaffer.operation.data.EntitySeed;
import uk.gov.gchq.gaffer.operation.graph.GraphFilters;
import uk.gov.gchq.gaffer.spark.SparkSessionProvider;
import uk.gov.gchq.gaffer.spark.operation.scalardd.GetRDDOfAllElements;
import uk.gov.gchq.gaffer.spark.operation.scalardd.GetRDDOfElements;
import uk.gov.gchq.gaffer.store.schema.Schema;
import uk.gov.gchq.koryphe.impl.predicate.IsLessThan;
import uk.gov.gchq.koryphe.impl.predicate.IsMoreThan;
import uk.gov.gchq.koryphe.tuple.predicate.TupleAdaptedPredicate;

/* loaded from: input_file:uk/gov/gchq/gaffer/spark/operation/dataframe/FilterToOperationConverterTest.class */
/**
 * Tests for {@link FiltersToOperationConverter}.
 *
 * <p>Every test builds a view covering all entity and edge groups of the test schema, hands it to
 * the converter together with one or more Spark {@link Filter}s, and then inspects the operation
 * returned by {@code getOperation()}: its concrete type, the groups remaining in its view, its
 * seeds (where applicable) and the post-aggregation filters attached to each group.
 */
public class FilterToOperationConverterTest {
    private static final String ENTITY_GROUP = "BasicEntity";
    private static final String EDGE_GROUP = "BasicEdge";
    private static final String EDGE_GROUP2 = "BasicEdge2";
    private static final Set<String> EDGE_GROUPS = new HashSet<>(Arrays.asList(EDGE_GROUP, EDGE_GROUP2));

    /** Two equality filters on {@code group} with different values must yield no operation. */
    @Test
    public void testIncompatibleGroups() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        Assertions.assertNull(
                converterFor(schema, new EqualTo("group", "A"), new EqualTo("group", "B")).getOperation());
    }

    /** Filtering on the entity group alone must leave only that group in the view. */
    @Test
    public void testSingleGroup() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GraphFilters operation = converterFor(schema, new EqualTo("group", ENTITY_GROUP)).getOperation();

        Assertions.assertTrue(operation instanceof GetRDDOfAllElements);
        Assertions.assertEquals(Collections.singleton(ENTITY_GROUP), operation.getView().getEntityGroups());
        Assertions.assertEquals(0, operation.getView().getEdgeGroups().size());
    }

    /** Filtering on a group name other than the three used by this test must yield no operation. */
    @Test
    public void testSingleGroupNotInSchema() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        Assertions.assertNull(converterFor(schema, new EqualTo("group", "random")).getOperation());
    }

    /** An {@link Or} of two group equalities must keep exactly those two groups in the view. */
    @Test
    public void testTwoGroups() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GraphFilters operation = converterFor(
                schema,
                new Or(new EqualTo("group", ENTITY_GROUP), new EqualTo("group", EDGE_GROUP2))).getOperation();

        Assertions.assertTrue(operation instanceof GetRDDOfAllElements);
        Assertions.assertEquals(Collections.singleton(ENTITY_GROUP), operation.getView().getEntityGroups());
        Assertions.assertEquals(Collections.singleton(EDGE_GROUP2), operation.getView().getEdgeGroups());
    }

    /** An equality filter on {@code vertex} must produce a seeded operation restricted to entities. */
    @Test
    public void testSpecifyVertex() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GetRDDOfElements operation = converterFor(schema, new EqualTo("vertex", "0")).getOperation();

        // With this declared type the instanceof check effectively asserts that the operation is non-null.
        Assertions.assertTrue(operation instanceof GetRDDOfElements);
        Assertions.assertEquals(Collections.singleton(ENTITY_GROUP), ((GraphFilters) operation).getView().getEntityGroups());
        Assertions.assertEquals(0, ((GraphFilters) operation).getView().getEdgeGroups().size());
        Assertions.assertEquals(Collections.singleton(new EntitySeed("0")), seedsOf(operation));
    }

    /** An equality filter on {@code src} must produce a seeded operation restricted to edges. */
    @Test
    public void testSpecifySource() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GetRDDOfElements operation = converterFor(schema, new EqualTo("src", "0")).getOperation();

        // With this declared type the instanceof check effectively asserts that the operation is non-null.
        Assertions.assertTrue(operation instanceof GetRDDOfElements);
        Assertions.assertEquals(0, ((GraphFilters) operation).getView().getEntityGroups().size());
        Assertions.assertEquals(EDGE_GROUPS, ((GraphFilters) operation).getView().getEdgeGroups());
        Assertions.assertEquals(Collections.singleton(new EntitySeed("0")), seedsOf(operation));
    }

    /** An equality filter on {@code dst} must produce a seeded operation restricted to edges. */
    @Test
    public void testSpecifyDestination() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GetRDDOfElements operation = converterFor(schema, new EqualTo("dst", "0")).getOperation();

        // With this declared type the instanceof check effectively asserts that the operation is non-null.
        Assertions.assertTrue(operation instanceof GetRDDOfElements);
        Assertions.assertEquals(0, ((GraphFilters) operation).getView().getEntityGroups().size());
        Assertions.assertEquals(EDGE_GROUPS, ((GraphFilters) operation).getView().getEdgeGroups());
        Assertions.assertEquals(Collections.singleton(new EntitySeed("0")), seedsOf(operation));
    }

    /**
     * A single property filter must be turned into post-aggregation filters on the view.
     *
     * <p>Three scenarios are covered: {@code property1 > 5}, {@code property4 < 8L}, and the
     * conjunction {@code property1 > 5 AND property4 > 8L} wrapped in one {@link And}.
     */
    @Test
    public void testSpecifyPropertyFilters() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        // Scenario 1: property1 > 5 is expected on the entity group and on every edge group.
        final GraphFilters greaterThanOperation = converterFor(schema, new GreaterThan("property1", 5)).getOperation();
        Assertions.assertTrue(greaterThanOperation instanceof GetRDDOfAllElements);
        final View greaterThanView = greaterThanOperation.getView();
        final List<?> greaterThanEntityFilters = greaterThanView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(greaterThanEntityFilters, 1);
        assertFilter(greaterThanEntityFilters, 0, "property1", new IsMoreThan(5, false));
        for (final String edgeGroup : EDGE_GROUPS) {
            final List<?> edgeFilters = greaterThanView.getEdge(edgeGroup).getPostAggregationFilterFunctions();
            assertFilterCount(edgeFilters, 1);
            assertFilter(edgeFilters, 0, "property1", new IsMoreThan(5, false));
        }

        // Scenario 2: property4 < 8L; only the entity group and BasicEdge are inspected.
        final GraphFilters lessThanOperation = converterFor(schema, new LessThan("property4", 8L)).getOperation();
        Assertions.assertTrue(lessThanOperation instanceof GetRDDOfAllElements);
        final View lessThanView = lessThanOperation.getView();
        final List<?> lessThanEntityFilters = lessThanView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(lessThanEntityFilters, 1);
        assertFilter(lessThanEntityFilters, 0, "property4", new IsLessThan(8L, false));
        final List<?> lessThanEdgeFilters = lessThanView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(lessThanEdgeFilters, 1);
        assertFilter(lessThanEdgeFilters, 0, "property4", new IsLessThan(8L, false));

        // Scenario 3: an And of two comparisons is expected to become two separate filters, in order.
        final GraphFilters andOperation = converterFor(
                schema,
                new And(new GreaterThan("property1", 5), new GreaterThan("property4", 8L))).getOperation();
        Assertions.assertTrue(andOperation instanceof GetRDDOfAllElements);
        final View andView = andOperation.getView();
        final List<?> andEntityFilters = andView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(andEntityFilters, 2);
        assertFilter(andEntityFilters, 0, "property1", new IsMoreThan(5, false));
        assertFilter(andEntityFilters, 1, "property4", new IsMoreThan(8L, false));
        // For the edge group only the selected properties are verified, not the predicates.
        final List<?> andEdgeFilters = andView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(andEdgeFilters, 2);
        assertSelection(andEdgeFilters, 0, "property1");
        assertSelection(andEdgeFilters, 1, "property4");
    }

    /** Two independent property filters must each become a post-aggregation filter, in order. */
    @Test
    public void testSpecifyMultiplePropertyFilters() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        final GraphFilters operation = converterFor(
                schema,
                new GreaterThan("property1", 5),
                new LessThan("property4", 8L)).getOperation();

        Assertions.assertTrue(operation instanceof GetRDDOfAllElements);
        final View view = operation.getView();

        final List<?> entityFilters = view.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(entityFilters, 2);
        assertFilter(entityFilters, 0, "property1", new IsMoreThan(5, false));
        assertFilter(entityFilters, 1, "property4", new IsLessThan(8L, false));

        // For the edge group only the selected properties are verified, not the predicates.
        final List<?> edgeFilters = view.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(edgeFilters, 2);
        assertSelection(edgeFilters, 0, "property1");
        assertSelection(edgeFilters, 1, "property4");
    }

    /**
     * A {@code vertex} equality combined with property filters must give a seeded, entity-only
     * operation whose view carries the property filters.
     */
    @Test
    public void testSpecifyVertexAndPropertyFilter() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        // One property filter plus the vertex seed.
        final GraphFilters singleFilterOperation = converterFor(
                schema,
                new GreaterThan("property1", 5),
                new EqualTo("vertex", "0")).getOperation();
        Assertions.assertTrue(singleFilterOperation instanceof GetRDDOfElements);
        Assertions.assertEquals(1, singleFilterOperation.getView().getEntityGroups().size());
        Assertions.assertEquals(0, singleFilterOperation.getView().getEdgeGroups().size());
        Assertions.assertEquals(
                Collections.singleton(new EntitySeed("0")),
                seedsOf((GetRDDOfElements) singleFilterOperation));
        final List<?> singleFilters = singleFilterOperation.getView().getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(singleFilters, 1);
        assertFilter(singleFilters, 0, "property1", new IsMoreThan(5, false));

        // Two property filters around the vertex seed.
        final GraphFilters twoFilterOperation = converterFor(
                schema,
                new GreaterThan("property1", 5),
                new EqualTo("vertex", "0"),
                new LessThan("property4", 8)).getOperation();
        Assertions.assertTrue(twoFilterOperation instanceof GetRDDOfElements);
        Assertions.assertEquals(1, twoFilterOperation.getView().getEntityGroups().size());
        Assertions.assertEquals(0, twoFilterOperation.getView().getEdgeGroups().size());
        Assertions.assertEquals(
                Collections.singleton(new EntitySeed("0")),
                seedsOf((GetRDDOfElements) twoFilterOperation));
        final List<?> twoFilters = twoFilterOperation.getView().getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(twoFilters, 2);
        assertFilter(twoFilters, 0, "property1", new IsMoreThan(5, false));
        assertFilter(twoFilters, 1, "property4", new IsLessThan(8, false));
    }

    /**
     * A {@code src} equality combined with property filters must give a seeded, edge-only
     * operation whose view carries the property filters. Only {@code src} is exercised here;
     * {@code dst} is not.
     */
    @Test
    public void testSpecifySourceOrDestinationAndPropertyFilter() {
        final Schema schema = getSchema();
        SparkSessionProvider.getSparkSession();

        // One property filter plus the source seed: both edge groups are expected to remain.
        final GraphFilters singleFilterOperation = converterFor(
                schema,
                new GreaterThan("property1", 5),
                new EqualTo("src", "0")).getOperation();
        Assertions.assertTrue(singleFilterOperation instanceof GetRDDOfElements);
        Assertions.assertEquals(0, singleFilterOperation.getView().getEntityGroups().size());
        Assertions.assertEquals(2, singleFilterOperation.getView().getEdgeGroups().size());
        Assertions.assertEquals(
                Collections.singleton(new EntitySeed("0")),
                seedsOf((GetRDDOfElements) singleFilterOperation));
        final View singleFilterView = singleFilterOperation.getView();
        for (final String edgeGroup : EDGE_GROUPS) {
            final List<?> edgeFilters = singleFilterView.getEdge(edgeGroup).getPostAggregationFilterFunctions();
            assertFilterCount(edgeFilters, 1);
            assertFilter(edgeFilters, 0, "property1", new IsMoreThan(5, false));
        }

        // Adding a property4 filter is expected to leave a single edge group; BasicEdge is inspected.
        // NOTE(review): presumably only BasicEdge defines property4 - confirm against the test schema.
        final GraphFilters twoFilterOperation = converterFor(
                schema,
                new GreaterThan("property1", 5),
                new EqualTo("src", "0"),
                new LessThan("property4", 8)).getOperation();
        Assertions.assertTrue(twoFilterOperation instanceof GetRDDOfElements);
        Assertions.assertEquals(0, twoFilterOperation.getView().getEntityGroups().size());
        Assertions.assertEquals(1, twoFilterOperation.getView().getEdgeGroups().size());
        Assertions.assertEquals(
                Collections.singleton(new EntitySeed("0")),
                seedsOf((GetRDDOfElements) twoFilterOperation));
        final List<?> twoFilters = twoFilterOperation.getView().getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
        assertFilterCount(twoFilters, 2);
        assertFilter(twoFilters, 0, "property1", new IsMoreThan(5, false));
        assertFilter(twoFilters, 1, "property4", new IsLessThan(8, false));
    }

    /** Loads the schema via {@link StreamUtil#schemas(Class)} for this test class. */
    private Schema getSchema() {
        return Schema.fromJson(StreamUtil.schemas(getClass()));
    }

    /** Builds a view containing every entity group and every edge group of {@code schema}. */
    private View getViewFromSchema(Schema schema) {
        return new View.Builder().entities(schema.getEntityGroups()).edges(schema.getEdgeGroups()).build();
    }

    /**
     * Creates the converter under test for the given filters, using a view that spans the
     * whole schema.
     *
     * @param schema  the schema handed to the converter and used to build the view
     * @param filters the Spark filters to convert
     * @return a new converter instance
     */
    private FiltersToOperationConverter converterFor(final Schema schema, final Filter... filters) {
        return new FiltersToOperationConverter(getViewFromSchema(schema), schema, filters);
    }

    /**
     * Collects the input of a seeded operation into a set so it can be compared regardless of order.
     *
     * @param operation the operation whose input is read
     * @return the seeds of {@code operation}; every input element is cast to {@link EntitySeed}
     */
    private static Set<EntitySeed> seedsOf(final GetRDDOfElements operation) {
        final Set<EntitySeed> seeds = new HashSet<>();
        for (final Object seed : operation.getInput()) {
            seeds.add((EntitySeed) seed);
        }
        return seeds;
    }

    /** Asserts that {@code filters} holds exactly {@code expected} post-aggregation filters. */
    private static void assertFilterCount(final List<?> filters, final int expected) {
        // Fully qualified because the JUnit class of the same simple name is already imported.
        org.assertj.core.api.Assertions.assertThat(filters).hasSize(expected);
    }

    /**
     * Asserts that the filter at {@code index} selects exactly the single given property.
     *
     * @param filters  the post-aggregation filters of one group
     * @param index    position of the filter to check
     * @param property the only property the filter is expected to select
     */
    @SuppressWarnings("rawtypes") // TupleAdaptedPredicate's type parameters are declared outside this file.
    private static void assertSelection(final List<?> filters, final int index, final String property) {
        final TupleAdaptedPredicate filter = (TupleAdaptedPredicate) filters.get(index);
        Assertions.assertArrayEquals(new String[]{property}, filter.getSelection());
    }

    /**
     * Asserts both the selected property and the wrapped predicate of the filter at {@code index}.
     *
     * @param filters   the post-aggregation filters of one group
     * @param index     position of the filter to check
     * @param property  the only property the filter is expected to select
     * @param predicate the predicate the filter is expected to wrap (compared with {@code equals})
     */
    @SuppressWarnings("rawtypes") // TupleAdaptedPredicate's type parameters are declared outside this file.
    private static void assertFilter(final List<?> filters, final int index, final String property, final Object predicate) {
        assertSelection(filters, index, property);
        Assertions.assertEquals(predicate, ((TupleAdaptedPredicate) filters.get(index)).getPredicate());
    }
}
