package fr.mines_stetienne.ci.sparql_generate.engine;

import fr.mines_stetienne.ci.sparql_generate.SPARQLExtException;
import fr.mines_stetienne.ci.sparql_generate.query.SPARQLExtQuery;
import fr.mines_stetienne.ci.sparql_generate.query.SPARQLExtQueryVisitor;
import fr.mines_stetienne.ci.sparql_generate.utils.ContextUtils;
import fr.mines_stetienne.ci.sparql_generate.utils.LogUtils;
import fr.mines_stetienne.ci.sparql_generate.utils.VarUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.jena.query.Dataset;
import org.apache.jena.query.Query;
import org.apache.jena.query.ResultSet;
import org.apache.jena.query.ResultSetFactory;
import org.apache.jena.query.ResultSetRewindable;
import org.apache.jena.sparql.core.Var;
import org.apache.jena.sparql.engine.QueryEngineRegistry;
import org.apache.jena.sparql.engine.QueryExecutionBase;
import org.apache.jena.sparql.engine.binding.Binding;
import org.apache.jena.sparql.engine.binding.BindingHashMap;
import org.apache.jena.sparql.syntax.ElementData;
import org.apache.jena.sparql.syntax.ElementGroup;
import org.apache.jena.sparql.util.Context;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:fr/mines_stetienne/ci/sparql_generate/engine/SelectPlan.class */
/**
 * Executes a SELECT query against the dataset found in the execution
 * {@link Context}, optionally seeding it with initial bindings through a
 * leading VALUES block, and hands the (rewindable) result set to a consumer.
 */
public class SelectPlan {
    private static final Logger LOG = LoggerFactory.getLogger(SelectPlan.class);

    /** The SELECT query this plan executes. */
    private final SPARQLExtQuery select;

    /**
     * When {@code false}, the input variables are appended to the projection
     * of the generated query (unless it groups or aggregates), see
     * {@link #createQuery}.
     */
    private final boolean isSelectType;

    /** Signature variables supplied at construction time; not read by this class. */
    private final List<Var> signature;

    /**
     * Creates a plan for the given query.
     *
     * @param sPARQLExtQuery the query to execute; must be of SELECT type
     * @param z              value stored as the {@code isSelectType} flag
     * @param list           the signature variables
     * @throws SPARQLExtException if the query is not a SELECT query
     */
    public SelectPlan(SPARQLExtQuery sPARQLExtQuery, boolean z, List<Var> list) {
        if (!sPARQLExtQuery.isSelectType()) {
            throw new SPARQLExtException("Should be select query. " + sPARQLExtQuery);
        }
        this.select = sPARQLExtQuery;
        this.isSelectType = z;
        this.signature = list;
    }

    /**
     * @return the projected variables of the underlying SELECT query
     */
    public List<Var> getVars() {
        return this.select.getProjectVars();
    }

    /**
     * Builds the concrete query, injects the input bindings as a VALUES
     * block, executes it and passes a rewindable copy of the results to
     * {@code consumer}.
     *
     * @param list     the input variables
     * @param list2    the input bindings for those variables
     * @param context  the execution context; also provides the dataset
     * @param consumer receives the result set, positioned at its start
     * @throws SPARQLExtException if the current thread was interrupted, or if
     *                            anything fails while preparing or executing
     *                            the query or while the consumer runs
     */
    public final void exec(List<Var> list, List<Binding> list2, Context context, Consumer<ResultSet> consumer) {
        // NOTE: Thread.interrupted() clears the thread's interrupt flag.
        if (Thread.interrupted()) {
            throw new SPARQLExtException(new InterruptedException());
        }
        Query query = createQuery(this.select, list, list2, context);
        Dataset dataset = ContextUtils.getDataset(context);
        logInput(query, list, list2);
        try {
            augmentQuery(query, list, list2);
            // try-with-resources guarantees the execution is closed even when
            // execSelect(), copyResults() or the consumer throws.
            try (QueryExecutionBase execution = new QueryExecutionBase(query, dataset, context,
                    QueryEngineRegistry.get().find(query, dataset.asDatasetGraph(), context))) {
                // Copy into memory so the results can be logged, rewound and
                // then consumed.
                ResultSetRewindable results = ResultSetFactory.copyResults(execution.execSelect());
                logOutput(results);
                consumer.accept(results);
            }
        } catch (Exception e) {
            LOG.error("Error while executing SELECT Query " + query, e);
            throw new SPARQLExtException("Error while executing SELECT Query " + query, e);
        }
    }

    /** Logs the query about to be executed, at TRACE (full detail) or DEBUG (summary) level. */
    private void logInput(Query query, List<Var> variables, List<Binding> values) {
        if (LOG.isTraceEnabled()) {
            StringBuilder sb = new StringBuilder("Executing select query:\n");
            sb.append(query.toString());
            if (variables.isEmpty() || values.isEmpty()) {
                sb.append(" \nwithout initial values.");
            } else {
                sb.append(" \nwith initial values:\n");
                sb.append(LogUtils.log(variables, values));
            }
            LOG.trace(sb.toString());
        } else if (LOG.isDebugEnabled()) {
            LOG.debug("Executing select query with {} bindings.", values.size());
        }
    }

    /**
     * Logs the query output at TRACE (all rows) or DEBUG (row count) level.
     * The result set is rewound afterwards so the caller can still consume it.
     */
    private void logOutput(ResultSetRewindable results) {
        if (LOG.isTraceEnabled()) {
            List<Var> variables = getVariables(results.getResultVars());
            List<Binding> rows = new ArrayList<>();
            while (results.hasNext()) {
                rows.add(results.nextBinding());
            }
            LOG.trace("Query output is\n{}", LogUtils.log(variables, rows));
            results.reset();
        } else if (LOG.isDebugEnabled()) {
            int count = 0;
            while (results.hasNext()) {
                results.next();
                count++;
            }
            LOG.debug("Query has {} output for variables {}", count, results.getResultVars());
            results.reset();
        }
    }

    /**
     * Produces the query to execute by running a
     * {@link SelectQueryPartialCopyVisitor} over the source query, using the
     * first input binding (or {@code null} when there is none).
     * When this plan is not flagged as select type and the copy neither
     * groups nor aggregates, every input variable missing from the
     * projection is added to it.
     */
    private Query createQuery(SPARQLExtQuery query, List<Var> variables, List<Binding> values, Context context) {
        Binding first = values.isEmpty() ? null : values.get(0);
        SelectQueryPartialCopyVisitor copier = new SelectQueryPartialCopyVisitor(first, context);
        query.visit((SPARQLExtQueryVisitor) copier);
        Query output = copier.getOutput();
        if (!this.isSelectType && !output.hasGroupBy() && !output.hasAggregators()) {
            for (Var v : variables) {
                if (!output.getProjectVars().contains(v)) {
                    output.getProject().add(v);
                }
            }
        }
        return output;
    }

    /**
     * Rewrites the query pattern so that it starts with a VALUES block
     * carrying the input bindings. If the pattern already starts with a
     * VALUES block, the two are merged; otherwise a new block is prepended.
     * Does nothing when there are no input variables.
     */
    private void augmentQuery(Query query, List<Var> variables, List<Binding> values) {
        if (variables.isEmpty()) {
            return;
        }
        // The pattern is treated as a group; the remaining code relies on
        // ElementGroup's size()/get()/getElements().
        ElementGroup oldPattern = (ElementGroup) query.getQueryPattern();
        ElementGroup newPattern = new ElementGroup();
        query.setQueryPattern(newPattern);

        if (oldPattern.size() >= 1 && (oldPattern.get(0) instanceof ElementData)) {
            ElementData existing = (ElementData) oldPattern.get(0);
            int sizeBefore = existing.getRows().size();
            ElementData merged = mergeValues(existing, variables, values);
            newPattern.addElement(merged);
            // Keep everything after the original VALUES block untouched.
            for (int i = 1; i < oldPattern.size(); i++) {
                newPattern.addElement(oldPattern.get(i));
            }
            LOG.debug("New query has {} initial values. It had {} values before",
                    merged.getRows().size(), sizeBefore);
            return;
        }

        ElementData data = new ElementData();
        variables.forEach(data::add);
        values.forEach(data::add);
        newPattern.addElement(data);
        oldPattern.getElements().forEach(newPattern::addElement);
        check(data, values);
    }

    /**
     * Combines an existing VALUES block with the input bindings: the result
     * declares the existing variables followed by the input variables, and
     * contains one row per (existing row, input binding) pair.
     *
     * @return {@code elementData} itself when there are no input bindings,
     *         otherwise a new block
     * @throws SPARQLExtException if an input variable is already declared by
     *                            the existing VALUES block
     */
    private ElementData mergeValues(ElementData elementData, List<Var> variables, List<Binding> values) {
        if (values.isEmpty()) {
            return elementData;
        }
        List<Var> existingVars = elementData.getVars();
        if (!Collections.disjoint(existingVars, variables)) {
            // Work on a copy: retainAll mutates its receiver and returns a
            // boolean, so it must not be applied to the block's own list nor
            // concatenated into the message.
            List<Var> alreadyBound = new ArrayList<>(existingVars);
            alreadyBound.retainAll(variables);
            throw new SPARQLExtException("Variables " + alreadyBound + " were already bound.");
        }
        ElementData merged = new ElementData();
        existingVars.forEach(merged::add);
        variables.forEach(merged::add);
        for (Binding existingRow : elementData.getRows()) {
            for (Binding inputRow : values) {
                // The existing row is the parent; the input binding supplies
                // the values of the additional variables.
                BindingHashMap row = new BindingHashMap(existingRow);
                for (Var v : variables) {
                    // Skip variables the input binding leaves unbound rather
                    // than adding a null node to the row.
                    if (inputRow.contains(v)) {
                        row.add(v, inputRow.get(v));
                    }
                }
                merged.add(row);
            }
        }
        return merged;
    }

    /**
     * Emits warnings when the VALUES block does not hold as many rows as the
     * collection of bindings it was built from.
     */
    private void check(ElementData elementData, Collection<Binding> collection) {
        int actual = elementData.getRows().size();
        int expected = collection.size();
        if (actual == expected) {
            return;
        }
        LOG.warn("Different size for the values block here.\n Was " + expected + ": \n" + collection
                + "\n now is " + actual + ": \n" + elementData.getRows());
        StringBuilder sb = new StringBuilder(
                "Different size for the values block here.\n Was " + expected + ": \n" + collection + "\n\n");
        int index = 0;
        for (Binding binding : collection) {
            sb.append("\nbinding ").append(index).append(" is ").append(binding);
            index++;
        }
        LOG.warn(sb.toString());
    }

    /** Converts variable names to {@link Var} instances via {@code VarUtils.allocVar}. */
    private List<Var> getVariables(List<String> names) {
        return names.stream().map(VarUtils::allocVar).collect(Collectors.toList());
    }
}
