package uk.gov.gchq.gaffer.spark.operation.dataframe;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.sources.And;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.LessThan;
import org.apache.spark.sql.sources.Or;
import org.junit.Assert;
import org.junit.Test;
import uk.gov.gchq.gaffer.commonutil.StreamUtil;
import uk.gov.gchq.gaffer.data.elementdefinition.view.View;
import uk.gov.gchq.gaffer.operation.OperationException;
import uk.gov.gchq.gaffer.operation.data.EntitySeed;
import uk.gov.gchq.gaffer.operation.graph.GraphFilters;
import uk.gov.gchq.gaffer.spark.SparkConstants;
import uk.gov.gchq.gaffer.spark.operation.scalardd.GetRDDOfAllElements;
import uk.gov.gchq.gaffer.spark.operation.scalardd.GetRDDOfElements;
import uk.gov.gchq.gaffer.store.schema.Schema;
import uk.gov.gchq.koryphe.impl.predicate.IsLessThan;
import uk.gov.gchq.koryphe.impl.predicate.IsMoreThan;
import uk.gov.gchq.koryphe.tuple.predicate.TupleAdaptedPredicate;

/**
 * Tests for {@link FiltersToOperationConverter}: each test builds Spark SQL
 * data-source filters and checks the operation and view returned by
 * {@code getOperation()}.
 */
public class FilterToOperationConverterTest {
    // Group names referenced by the tests; presumably they match the groups
    // declared in the schema loaded by getSchema() — verify against the schema resources.
    private static final String ENTITY_GROUP = "BasicEntity";
    private static final String EDGE_GROUP = "BasicEdge";
    private static final String EDGE_GROUP2 = "BasicEdge2";
    // Both edge groups, used when a view is expected to contain every edge group.
    private static final Set<String> EDGE_GROUPS = new HashSet<>(Arrays.asList(EDGE_GROUP, EDGE_GROUP2));

    /**
     * Two equality filters on the group column that name different groups:
     * the test expects {@code getOperation()} to return {@code null}.
     */
    @Test
    public void testIncompatibleGroups() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testIncompatibleGroups"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("group", "A"), new EqualTo("group", "B")};
            final FiltersToOperationConverter converter =
                    new FiltersToOperationConverter(sparkSession, getViewFromSchema(schema), schema, filters);
            Assert.assertNull(converter.getOperation());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * A single equality filter on the group column naming the entity group:
     * the test expects a {@code GetRDDOfAllElements} whose view contains only
     * that entity group and no edge groups.
     */
    @Test
    public void testSingleGroup() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSingleGroup"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("group", ENTITY_GROUP)};
            final GraphFilters operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfAllElements);
            final View view = operation.getView();
            Assert.assertEquals(Collections.singleton(ENTITY_GROUP), view.getEntityGroups());
            Assert.assertEquals(0L, view.getEdgeGroups().size());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * An equality filter on the group column naming a group ("random") that
     * the tests do not otherwise use: the test expects {@code getOperation()}
     * to return {@code null}.
     */
    @Test
    public void testSingleGroupNotInSchema() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSingleGroupNotInSchema"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("group", "random")};
            final FiltersToOperationConverter converter =
                    new FiltersToOperationConverter(sparkSession, getViewFromSchema(schema), schema, filters);
            Assert.assertNull(converter.getOperation());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * An OR of two group equality filters (one entity group, one edge group):
     * the test expects a {@code GetRDDOfAllElements} whose view contains
     * exactly those two groups.
     */
    @Test
    public void testTwoGroups() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testTwoGroups"))
                .getOrCreate();
        try {
            final Filter eitherGroup =
                    new Or(new EqualTo("group", ENTITY_GROUP), new EqualTo("group", EDGE_GROUP2));
            final Filter[] filters = {eitherGroup};
            final GraphFilters operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfAllElements);
            final View view = operation.getView();
            Assert.assertEquals(Collections.singleton(ENTITY_GROUP), view.getEntityGroups());
            Assert.assertEquals(Collections.singleton(EDGE_GROUP2), view.getEdgeGroups());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * An equality filter on the vertex column: the test expects a
     * {@code GetRDDOfElements} restricted to the entity group, with no edge
     * groups, and seeded with a single {@code EntitySeed("0")}.
     */
    @Test
    public void testSpecifyVertex() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifyVertex"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("vertex", "0")};
            GetRDDOfElements operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfElements);
            Assert.assertEquals(Collections.singleton(ENTITY_GROUP), ((GraphFilters) operation).getView().getEntityGroups());
            Assert.assertEquals(0L, ((GraphFilters) operation).getView().getEdgeGroups().size());
            // Collect the operation input into a set so it can be compared regardless of order.
            final Set<EntitySeed> seeds = new HashSet<>();
            for (final Object seed : operation.getInput()) {
                seeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), seeds);
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * An equality filter on the src column: the test expects a
     * {@code GetRDDOfElements} with no entity groups, both edge groups, and a
     * single {@code EntitySeed("0")} as input.
     */
    @Test
    public void testSpecifySource() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifySource"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("src", "0")};
            GetRDDOfElements operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfElements);
            Assert.assertEquals(0L, ((GraphFilters) operation).getView().getEntityGroups().size());
            Assert.assertEquals(EDGE_GROUPS, ((GraphFilters) operation).getView().getEdgeGroups());
            // Collect the operation input into a set so it can be compared regardless of order.
            final Set<EntitySeed> seeds = new HashSet<>();
            for (final Object seed : operation.getInput()) {
                seeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), seeds);
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * An equality filter on the dst column: the test expects a
     * {@code GetRDDOfElements} with no entity groups, both edge groups, and a
     * single {@code EntitySeed("0")} as input.
     */
    @Test
    public void testSpecifyDestination() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifyDestination"))
                .getOrCreate();
        try {
            final Filter[] filters = {new EqualTo("dst", "0")};
            GetRDDOfElements operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfElements);
            Assert.assertEquals(0L, ((GraphFilters) operation).getView().getEntityGroups().size());
            Assert.assertEquals(EDGE_GROUPS, ((GraphFilters) operation).getView().getEdgeGroups());
            // Collect the operation input into a set so it can be compared regardless of order.
            final Set<EntitySeed> seeds = new HashSet<>();
            for (final Object seed : operation.getInput()) {
                seeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), seeds);
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * Property comparisons are expected to appear as post-aggregation filter
     * functions in the view of a {@code GetRDDOfAllElements}. Three cases are
     * checked: a single greater-than filter, a single less-than filter, and an
     * AND of two greater-than filters.
     */
    @Test
    public void testSpecifyPropertyFilters() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifyPropertyFilters"))
                .getOrCreate();
        try {
            // Case 1: property1 > 5 is expected on the entity group and on every edge group.
            final Filter[] greaterThanFilters = {new GreaterThan("property1", 5)};
            final GraphFilters greaterThanOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, greaterThanFilters).getOperation();
            Assert.assertTrue(greaterThanOperation instanceof GetRDDOfAllElements);
            final View greaterThanView = greaterThanOperation.getView();
            final List<?> greaterThanEntityFilters =
                    greaterThanView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(1L, greaterThanEntityFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) greaterThanEntityFilters.get(0)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) greaterThanEntityFilters.get(0)).getPredicate());
            for (final String edgeGroup : EDGE_GROUPS) {
                final List<?> edgeFilters = greaterThanView.getEdge(edgeGroup).getPostAggregationFilterFunctions();
                Assert.assertEquals(1L, edgeFilters.size());
                Assert.assertArrayEquals(new String[]{"property1"},
                        ((TupleAdaptedPredicate) edgeFilters.get(0)).getSelection());
                Assert.assertEquals(new IsMoreThan(5, false),
                        ((TupleAdaptedPredicate) edgeFilters.get(0)).getPredicate());
            }

            // Case 2: property4 < 8L is checked on the entity group and on EDGE_GROUP only.
            final Filter[] lessThanFilters = {new LessThan("property4", 8L)};
            final GraphFilters lessThanOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, lessThanFilters).getOperation();
            Assert.assertTrue(lessThanOperation instanceof GetRDDOfAllElements);
            final View lessThanView = lessThanOperation.getView();
            final List<?> lessThanEntityFilters =
                    lessThanView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(1L, lessThanEntityFilters.size());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) lessThanEntityFilters.get(0)).getSelection());
            Assert.assertEquals(new IsLessThan(8L, false),
                    ((TupleAdaptedPredicate) lessThanEntityFilters.get(0)).getPredicate());
            final List<?> lessThanEdgeFilters =
                    lessThanView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(1L, lessThanEdgeFilters.size());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) lessThanEdgeFilters.get(0)).getSelection());
            Assert.assertEquals(new IsLessThan(8L, false),
                    ((TupleAdaptedPredicate) lessThanEdgeFilters.get(0)).getPredicate());

            // Case 3: (property1 > 5 AND property4 > 8L) is expected to yield two functions, in that order.
            final Filter[] andFilters = {
                new And(new GreaterThan("property1", 5), new GreaterThan("property4", 8L))
            };
            final GraphFilters andOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, andFilters).getOperation();
            Assert.assertTrue(andOperation instanceof GetRDDOfAllElements);
            final View andView = andOperation.getView();
            final List<?> andEntityFilters =
                    andView.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, andEntityFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) andEntityFilters.get(0)).getSelection());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) andEntityFilters.get(1)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) andEntityFilters.get(0)).getPredicate());
            Assert.assertEquals(new IsMoreThan(8L, false),
                    ((TupleAdaptedPredicate) andEntityFilters.get(1)).getPredicate());
            // Only the selections are checked for the edge group in this case.
            final List<?> andEdgeFilters =
                    andView.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, andEdgeFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) andEdgeFilters.get(0)).getSelection());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) andEdgeFilters.get(1)).getSelection());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * Two separate property filters (property1 &gt; 5 and property4 &lt; 8L):
     * the test expects a {@code GetRDDOfAllElements} whose view carries both
     * as post-aggregation filter functions, in that order.
     */
    @Test
    public void testSpecifyMultiplePropertyFilters() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifyMultiplePropertyFilters"))
                .getOrCreate();
        try {
            final Filter[] filters = {new GreaterThan("property1", 5), new LessThan("property4", 8L)};
            final GraphFilters operation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, filters).getOperation();
            Assert.assertTrue(operation instanceof GetRDDOfAllElements);
            final View view = operation.getView();

            // Entity group: both selections and both predicates are checked.
            final List<?> entityFilters = view.getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, entityFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) entityFilters.get(0)).getSelection());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) entityFilters.get(1)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) entityFilters.get(0)).getPredicate());
            Assert.assertEquals(new IsLessThan(8L, false),
                    ((TupleAdaptedPredicate) entityFilters.get(1)).getPredicate());

            // Edge group: only the selections are checked.
            final List<?> edgeFilters = view.getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, edgeFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) edgeFilters.get(0)).getSelection());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) edgeFilters.get(1)).getSelection());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * A vertex equality filter combined with property filters: the test
     * expects a {@code GetRDDOfElements} seeded with {@code EntitySeed("0")},
     * one entity group, no edge groups, and the property filters as
     * post-aggregation filter functions on the entity group. Checked first
     * with one property filter, then with two.
     */
    @Test
    public void testSpecifyVertexAndPropertyFilter() throws OperationException {
        final Schema schema = getSchema();
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifyVertexAndPropertyFilter"))
                .getOrCreate();
        try {
            // Case 1: one property filter plus the vertex filter.
            final Filter[] singlePropertyFilters = {new GreaterThan("property1", 5), new EqualTo("vertex", "0")};
            final GraphFilters singleOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, singlePropertyFilters).getOperation();
            Assert.assertTrue(singleOperation instanceof GetRDDOfElements);
            Assert.assertEquals(1L, singleOperation.getView().getEntityGroups().size());
            Assert.assertEquals(0L, singleOperation.getView().getEdgeGroups().size());
            final Set<EntitySeed> singleSeeds = new HashSet<>();
            for (final Object seed : ((GetRDDOfElements) singleOperation).getInput()) {
                singleSeeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), singleSeeds);
            final List<?> singleEntityFilters =
                    singleOperation.getView().getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(1L, singleEntityFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) singleEntityFilters.get(0)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) singleEntityFilters.get(0)).getPredicate());

            // Case 2: two property filters around the vertex filter.
            final Filter[] twoPropertyFilters = {
                new GreaterThan("property1", 5), new EqualTo("vertex", "0"), new LessThan("property4", 8)
            };
            final GraphFilters twoOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, twoPropertyFilters).getOperation();
            Assert.assertTrue(twoOperation instanceof GetRDDOfElements);
            Assert.assertEquals(1L, twoOperation.getView().getEntityGroups().size());
            Assert.assertEquals(0L, twoOperation.getView().getEdgeGroups().size());
            final Set<EntitySeed> twoSeeds = new HashSet<>();
            for (final Object seed : ((GetRDDOfElements) twoOperation).getInput()) {
                twoSeeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), twoSeeds);
            final List<?> twoEntityFilters =
                    twoOperation.getView().getEntity(ENTITY_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, twoEntityFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) twoEntityFilters.get(0)).getSelection());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) twoEntityFilters.get(1)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) twoEntityFilters.get(0)).getPredicate());
            Assert.assertEquals(new IsLessThan(8, false),
                    ((TupleAdaptedPredicate) twoEntityFilters.get(1)).getPredicate());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * A src equality filter combined with property filters: the test expects
     * a {@code GetRDDOfElements} seeded with {@code EntitySeed("0")}, no
     * entity groups, and the property filters as post-aggregation filter
     * functions on the edge groups. With only property1 both edge groups are
     * expected; once property4 is added only one edge group is expected.
     * NOTE(review): despite the name, only the src column is exercised here,
     * not dst.
     */
    @Test
    public void testSpecifySourceOrDestinationAndPropertyFilter() throws OperationException {
        final Schema schema = getSchema();
        // The app name previously duplicated "testSpecifyVertexAndPropertyFilter"; it now matches this test.
        final SparkSession sparkSession = SparkSession.builder()
                .config(getSparkConf("testSpecifySourceOrDestinationAndPropertyFilter"))
                .getOrCreate();
        try {
            // Case 1: one property filter plus the src filter; both edge groups are expected.
            final Filter[] singlePropertyFilters = {new GreaterThan("property1", 5), new EqualTo("src", "0")};
            final GraphFilters singleOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, singlePropertyFilters).getOperation();
            Assert.assertTrue(singleOperation instanceof GetRDDOfElements);
            Assert.assertEquals(0L, singleOperation.getView().getEntityGroups().size());
            Assert.assertEquals(2L, singleOperation.getView().getEdgeGroups().size());
            final Set<EntitySeed> singleSeeds = new HashSet<>();
            for (final Object seed : ((GetRDDOfElements) singleOperation).getInput()) {
                singleSeeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), singleSeeds);
            final View singleView = singleOperation.getView();
            for (final String edgeGroup : EDGE_GROUPS) {
                final List<?> edgeFilters = singleView.getEdge(edgeGroup).getPostAggregationFilterFunctions();
                Assert.assertEquals(1L, edgeFilters.size());
                Assert.assertArrayEquals(new String[]{"property1"},
                        ((TupleAdaptedPredicate) edgeFilters.get(0)).getSelection());
                Assert.assertEquals(new IsMoreThan(5, false),
                        ((TupleAdaptedPredicate) edgeFilters.get(0)).getPredicate());
            }

            // Case 2: adding a property4 filter; a single edge group is expected to remain.
            final Filter[] twoPropertyFilters = {
                new GreaterThan("property1", 5), new EqualTo("src", "0"), new LessThan("property4", 8)
            };
            final GraphFilters twoOperation = new FiltersToOperationConverter(
                    sparkSession, getViewFromSchema(schema), schema, twoPropertyFilters).getOperation();
            Assert.assertTrue(twoOperation instanceof GetRDDOfElements);
            Assert.assertEquals(0L, twoOperation.getView().getEntityGroups().size());
            Assert.assertEquals(1L, twoOperation.getView().getEdgeGroups().size());
            final Set<EntitySeed> twoSeeds = new HashSet<>();
            for (final Object seed : ((GetRDDOfElements) twoOperation).getInput()) {
                twoSeeds.add((EntitySeed) seed);
            }
            Assert.assertEquals(Collections.singleton(new EntitySeed("0")), twoSeeds);
            final List<?> twoEdgeFilters =
                    twoOperation.getView().getEdge(EDGE_GROUP).getPostAggregationFilterFunctions();
            Assert.assertEquals(2L, twoEdgeFilters.size());
            Assert.assertArrayEquals(new String[]{"property1"},
                    ((TupleAdaptedPredicate) twoEdgeFilters.get(0)).getSelection());
            Assert.assertEquals(new IsMoreThan(5, false),
                    ((TupleAdaptedPredicate) twoEdgeFilters.get(0)).getPredicate());
            Assert.assertArrayEquals(new String[]{"property4"},
                    ((TupleAdaptedPredicate) twoEdgeFilters.get(1)).getSelection());
            Assert.assertEquals(new IsLessThan(8, false),
                    ((TupleAdaptedPredicate) twoEdgeFilters.get(1)).getPredicate());
        } finally {
            // Stop the context even when an assertion fails so it cannot leak into later tests.
            sparkSession.sparkContext().stop();
        }
    }

    /**
     * Loads the schema used by the tests by passing the result of
     * {@code StreamUtil.schemas} for this test class to {@code Schema.fromJson}.
     *
     * @return the schema parsed from those streams
     */
    private Schema getSchema() {
        final Schema testSchema = Schema.fromJson(StreamUtil.schemas(this.getClass()));
        return testSchema;
    }

    /**
     * Builds a view containing every entity group and every edge group of
     * the given schema.
     *
     * @param schema the schema whose groups populate the view
     * @return a view over all of the schema's entity and edge groups
     */
    private View getViewFromSchema(Schema schema) {
        final Set<String> entityGroups = schema.getEntityGroups();
        final Set<String> edgeGroups = schema.getEdgeGroups();
        return new View.Builder()
                .entities(entityGroups)
                .edges(edgeGroups)
                .build();
    }

    /**
     * Creates the Spark configuration used by the tests: master "local", the
     * given application name, the serializer and Kryo registrator taken from
     * {@code SparkConstants}, and multiple contexts allowed in the driver.
     *
     * @param str the Spark application name
     * @return the populated configuration
     */
    private SparkConf getSparkConf(String str) {
        final SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local");
        sparkConf.setAppName(str);
        sparkConf.set("spark.serializer", SparkConstants.DEFAULT_SERIALIZER);
        sparkConf.set("spark.kryo.registrator", SparkConstants.DEFAULT_KRYO_REGISTRATOR);
        sparkConf.set("spark.driver.allowMultipleContexts", "true");
        return sparkConf;
    }
}
