package io.substrait.dsl;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Fluent helpers for assembling Substrait relations ({@link Rel}), expressions ({@link Expression})
 * and plans ({@link Plan}).
 *
 * <p>Relation helpers that wrap an input take a function of that input (for example a
 * {@code Function<Rel, Expression>} for a filter condition). Callers can therefore derive field
 * references from the very relation being wrapped. Every relation helper also has an overload
 * that attaches a {@link Rel.Remap} to the produced relation.
 *
 * <p>Function helpers ({@link #scalarFn}, {@link #aggregateFn}, {@link #equal}, {@link #count},
 * ...) resolve their declarations from the {@link SimpleExtension.ExtensionCollection} passed to
 * the constructor.
 */
public class SubstraitBuilder {
    /**
     * Type creator for required (non-nullable) types. {@link #or} picks {@code R.BOOLEAN} when no
     * argument is nullable and {@code N.BOOLEAN} otherwise.
     */
    static final TypeCreator R = TypeCreator.of(false);
    /** Type creator for nullable types. */
    static final TypeCreator N = TypeCreator.of(true);

    // Extension YAML URIs passed to SimpleExtension.FunctionAnchor.of(...) when resolving
    // function declarations from the extension collection.
    private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
    private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
    private static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
    private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";

    /** Source of scalar and aggregate function declarations for the function helpers. */
    private final SimpleExtension.ExtensionCollection extensions;

    /**
     * The two inputs of a join, handed to join-condition functions so that a condition can build
     * field references against either side.
     */
    public static final class JoinInput {
        private final Rel left;
        private final Rel right;

        /**
         * @param left the left input of the join
         * @param right the right input of the join
         */
        public JoinInput(Rel left, Rel right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public String toString() {
            return "JoinInput[left=" + left + ",right=" + right + "]";
        }

        @Override
        public int hashCode() {
            // 31 * h(left) + h(right), with h(null) == 0.
            return 31 * Objects.hashCode(left) + Objects.hashCode(right);
        }

        @Override
        public final boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            JoinInput other = (JoinInput) obj;
            return Objects.equals(other.left, left) && Objects.equals(other.right, right);
        }

        /** @return the left input of the join */
        public Rel left() {
            return left;
        }

        /** @return the right input of the join */
        public Rel right() {
            return right;
        }
    }

    /**
     * Creates a builder whose function helpers resolve declarations from {@code extensions}.
     *
     * @param extensions extension collection used by {@link #scalarFn}, {@link #aggregateFn} and
     *     the other function helpers
     */
    public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
        this.extensions = extensions;
    }

    // ---------------------------------------------------------------- Aggregate

    /**
     * Builds an {@link Aggregate} over {@code input} with a single grouping and no remap.
     *
     * @param groupingFn produces the grouping from {@code input}
     * @param measuresFn produces the aggregate invocations from {@code input}; each one becomes an
     *     {@link Aggregate.Measure}
     * @param input the relation being aggregated
     * @return the aggregate relation
     */
    public Aggregate aggregate(
            Function<Rel, Aggregate.Grouping> groupingFn,
            Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
            Rel input) {
        return aggregate(
                groupingFn.andThen(grouping -> Arrays.asList(grouping)), measuresFn, Optional.empty(), input);
    }

    /**
     * Builds an {@link Aggregate} over {@code input} with a single grouping and the given remap.
     *
     * @param groupingFn produces the grouping from {@code input}
     * @param measuresFn produces the aggregate invocations from {@code input}
     * @param remap output remap attached to the aggregate
     * @param input the relation being aggregated
     * @return the aggregate relation
     */
    public Aggregate aggregate(
            Function<Rel, Aggregate.Grouping> groupingFn,
            Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
            Rel.Remap remap,
            Rel input) {
        return aggregate(
                groupingFn.andThen(grouping -> Arrays.asList(grouping)), measuresFn, Optional.of(remap), input);
    }

    private Aggregate aggregate(
            Function<Rel, List<Aggregate.Grouping>> groupingsFn,
            Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
            Optional<Rel.Remap> remap,
            Rel input) {
        // Groupings are computed before measures, matching the historical evaluation order.
        List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
        return Aggregate.builder()
                .groupings(groupings)
                .measures(measuresFn.apply(input).stream()
                        .map(invocation -> Aggregate.Measure.builder().function(invocation).build())
                        .collect(Collectors.toList()))
                .remap(remap)
                .input(input)
                .build();
    }

    // ---------------------------------------------------------------- Cross

    /** Builds a {@link Cross} product of {@code left} and {@code right} without a remap. */
    public Cross cross(Rel left, Rel right) {
        return cross(left, right, Optional.empty());
    }

    /** Builds a {@link Cross} product of {@code left} and {@code right} with the given remap. */
    public Cross cross(Rel left, Rel right, Rel.Remap remap) {
        return cross(left, right, Optional.of(remap));
    }

    private Cross cross(Rel left, Rel right, Optional<Rel.Remap> remap) {
        return Cross.builder().left(left).right(right).remap(remap).build();
    }

    // ---------------------------------------------------------------- Fetch

    /**
     * Builds a {@link Fetch} over {@code input} without a remap.
     *
     * @param offset value passed to {@code Fetch.Builder.offset}
     * @param count value passed to {@code Fetch.Builder.count}
     * @param input the relation being fetched from
     */
    public Fetch fetch(long offset, long count, Rel input) {
        return fetch(offset, count, Optional.empty(), input);
    }

    /**
     * Builds a {@link Fetch} over {@code input} with the given remap.
     *
     * @param offset value passed to {@code Fetch.Builder.offset}
     * @param count value passed to {@code Fetch.Builder.count}
     * @param remap output remap attached to the fetch
     * @param input the relation being fetched from
     */
    public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) {
        return fetch(offset, count, Optional.of(remap), input);
    }

    private Fetch fetch(long offset, long count, Optional<Rel.Remap> remap, Rel input) {
        return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build();
    }

    // ---------------------------------------------------------------- Filter

    /** Builds a {@link Filter} whose condition is {@code conditionFn.apply(input)}, without a remap. */
    public Filter filter(Function<Rel, Expression> conditionFn, Rel input) {
        return filter(conditionFn, Optional.empty(), input);
    }

    /** Builds a {@link Filter} whose condition is {@code conditionFn.apply(input)}, with a remap. */
    public Filter filter(Function<Rel, Expression> conditionFn, Rel.Remap remap, Rel input) {
        return filter(conditionFn, Optional.of(remap), input);
    }

    private Filter filter(Function<Rel, Expression> conditionFn, Optional<Rel.Remap> remap, Rel input) {
        return Filter.builder().input(input).condition(conditionFn.apply(input)).remap(remap).build();
    }

    // ---------------------------------------------------------------- Join

    /** Builds an {@link Join.JoinType#INNER INNER} {@link Join} without a remap. */
    public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel left, Rel right) {
        return join(conditionFn, Join.JoinType.INNER, left, right);
    }

    /** Builds an {@link Join.JoinType#INNER INNER} {@link Join} with the given remap. */
    public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel.Remap remap, Rel left, Rel right) {
        return join(conditionFn, Join.JoinType.INNER, remap, left, right);
    }

    /**
     * Builds a {@link Join} of the given type without a remap.
     *
     * @param conditionFn produces the join condition from a {@link JoinInput} wrapping both sides
     * @param joinType the join type
     * @param left left input
     * @param right right input
     */
    public Join join(Function<JoinInput, Expression> conditionFn, Join.JoinType joinType, Rel left, Rel right) {
        return join(conditionFn, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #join(Function, Join.JoinType, Rel, Rel)} but attaches {@code remap}. */
    public Join join(
            Function<JoinInput, Expression> conditionFn,
            Join.JoinType joinType,
            Rel.Remap remap,
            Rel left,
            Rel right) {
        return join(conditionFn, joinType, Optional.of(remap), left, right);
    }

    private Join join(
            Function<JoinInput, Expression> conditionFn,
            Join.JoinType joinType,
            Optional<Rel.Remap> remap,
            Rel left,
            Rel right) {
        return Join.builder()
                .left(left)
                .right(right)
                .condition(conditionFn.apply(new JoinInput(left, right)))
                .joinType(joinType)
                .remap(remap)
                .build();
    }

    // ---------------------------------------------------------------- Physical joins

    /**
     * Builds a {@link HashJoin} without a remap.
     *
     * @param leftKeys field indexes into {@code left} used as join keys
     * @param rightKeys field indexes into {@code right} used as join keys
     * @param joinType the join type
     * @param left left input
     * @param right right input
     */
    public HashJoin hashJoin(
            List<Integer> leftKeys, List<Integer> rightKeys, HashJoin.JoinType joinType, Rel left, Rel right) {
        return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #hashJoin(List, List, HashJoin.JoinType, Rel, Rel)} but attaches {@code remap}. */
    public HashJoin hashJoin(
            List<Integer> leftKeys,
            List<Integer> rightKeys,
            HashJoin.JoinType joinType,
            Rel.Remap remap,
            Rel left,
            Rel right) {
        return hashJoin(leftKeys, rightKeys, joinType, Optional.of(remap), left, right);
    }

    /**
     * Builds a {@link HashJoin} with an optional remap.
     *
     * @param leftKeys field indexes into {@code left} used as join keys
     * @param rightKeys field indexes into {@code right} used as join keys
     * @param joinType the join type
     * @param remap remap to attach, or empty for none
     * @param left left input
     * @param right right input
     */
    public HashJoin hashJoin(
            List<Integer> leftKeys,
            List<Integer> rightKeys,
            HashJoin.JoinType joinType,
            Optional<Rel.Remap> remap,
            Rel left,
            Rel right) {
        return HashJoin.builder()
                .left(left)
                .right(right)
                .leftKeys(fieldReferences(left, toIntArray(leftKeys)))
                .rightKeys(fieldReferences(right, toIntArray(rightKeys)))
                .joinType(joinType)
                .remap(remap)
                .build();
    }

    /**
     * Builds a {@link MergeJoin} without a remap.
     *
     * @param leftKeys field indexes into {@code left} used as join keys
     * @param rightKeys field indexes into {@code right} used as join keys
     * @param joinType the join type
     * @param left left input
     * @param right right input
     */
    public MergeJoin mergeJoin(
            List<Integer> leftKeys, List<Integer> rightKeys, MergeJoin.JoinType joinType, Rel left, Rel right) {
        return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #mergeJoin(List, List, MergeJoin.JoinType, Rel, Rel)} but attaches {@code remap}. */
    public MergeJoin mergeJoin(
            List<Integer> leftKeys,
            List<Integer> rightKeys,
            MergeJoin.JoinType joinType,
            Rel.Remap remap,
            Rel left,
            Rel right) {
        return mergeJoin(leftKeys, rightKeys, joinType, Optional.of(remap), left, right);
    }

    /**
     * Builds a {@link MergeJoin} with an optional remap.
     *
     * @param leftKeys field indexes into {@code left} used as join keys
     * @param rightKeys field indexes into {@code right} used as join keys
     * @param joinType the join type
     * @param remap remap to attach, or empty for none
     * @param left left input
     * @param right right input
     */
    public MergeJoin mergeJoin(
            List<Integer> leftKeys,
            List<Integer> rightKeys,
            MergeJoin.JoinType joinType,
            Optional<Rel.Remap> remap,
            Rel left,
            Rel right) {
        return MergeJoin.builder()
                .left(left)
                .right(right)
                .leftKeys(fieldReferences(left, toIntArray(leftKeys)))
                .rightKeys(fieldReferences(right, toIntArray(rightKeys)))
                .joinType(joinType)
                .remap(remap)
                .build();
    }

    /**
     * Builds a {@link NestedLoopJoin} without a remap.
     *
     * @param conditionFn produces the join condition from a {@link JoinInput} wrapping both sides
     * @param joinType the join type
     * @param left left input
     * @param right right input
     */
    public NestedLoopJoin nestedLoopJoin(
            Function<JoinInput, Expression> conditionFn, NestedLoopJoin.JoinType joinType, Rel left, Rel right) {
        return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
    }

    /**
     * Same as {@link #nestedLoopJoin(Function, NestedLoopJoin.JoinType, Rel, Rel)} but attaches
     * {@code remap}. This mirrors the remap overloads offered by the other relation helpers.
     */
    public NestedLoopJoin nestedLoopJoin(
            Function<JoinInput, Expression> conditionFn,
            NestedLoopJoin.JoinType joinType,
            Rel.Remap remap,
            Rel left,
            Rel right) {
        return nestedLoopJoin(conditionFn, joinType, Optional.of(remap), left, right);
    }

    private NestedLoopJoin nestedLoopJoin(
            Function<JoinInput, Expression> conditionFn,
            NestedLoopJoin.JoinType joinType,
            Optional<Rel.Remap> remap,
            Rel left,
            Rel right) {
        return NestedLoopJoin.builder()
                .left(left)
                .right(right)
                .condition(conditionFn.apply(new JoinInput(left, right)))
                .joinType(joinType)
                .remap(remap)
                .build();
    }

    /** Unboxes a list of field indexes into an {@code int[]} for {@link #fieldReferences(Rel, int...)}. */
    private static int[] toIntArray(List<Integer> indexes) {
        return indexes.stream().mapToInt(Integer::intValue).toArray();
    }

    // ---------------------------------------------------------------- NamedScan

    /**
     * Builds a {@link NamedScan} without a remap.
     *
     * @param tableName names identifying the scanned object (passed to {@code NamedScan.Builder.names})
     * @param columnNames names of the columns of the initial schema
     * @param types types of the columns; wrapped in a non-nullable struct
     */
    public NamedScan namedScan(Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
        return namedScan(tableName, columnNames, types, Optional.empty());
    }

    /** Same as {@link #namedScan(Iterable, Iterable, Iterable)} but attaches {@code remap}. */
    public NamedScan namedScan(
            Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types, Rel.Remap remap) {
        return namedScan(tableName, columnNames, types, Optional.of(remap));
    }

    private NamedScan namedScan(
            Iterable<String> tableName,
            Iterable<String> columnNames,
            Iterable<Type> types,
            Optional<Rel.Remap> remap) {
        return NamedScan.builder()
                .names(tableName)
                .initialSchema(NamedStruct.of(
                        columnNames, Type.Struct.builder().addAllFields(types).nullable(false).build()))
                .remap(remap)
                .build();
    }

    // ---------------------------------------------------------------- Project / Set / Sort

    /** Builds a {@link Project} whose expressions are {@code expressionsFn.apply(input)}, without a remap. */
    public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
        return project(expressionsFn, Optional.empty(), input);
    }

    /** Builds a {@link Project} whose expressions are {@code expressionsFn.apply(input)}, with a remap. */
    public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel.Remap remap, Rel input) {
        return project(expressionsFn, Optional.of(remap), input);
    }

    private Project project(
            Function<Rel, Iterable<? extends Expression>> expressionsFn, Optional<Rel.Remap> remap, Rel input) {
        return Project.builder().input(input).expressions(expressionsFn.apply(input)).remap(remap).build();
    }

    /** Builds a {@link Set} relation applying {@code setOp} to {@code inputs}, without a remap. */
    public Set set(Set.SetOp setOp, Rel... inputs) {
        return set(setOp, Optional.empty(), inputs);
    }

    /** Builds a {@link Set} relation applying {@code setOp} to {@code inputs}, with a remap. */
    public Set set(Set.SetOp setOp, Rel.Remap remap, Rel... inputs) {
        return set(setOp, Optional.of(remap), inputs);
    }

    private Set set(Set.SetOp setOp, Optional<Rel.Remap> remap, Rel... inputs) {
        return Set.builder().setOp(setOp).remap(remap).addAllInputs(Arrays.asList(inputs)).build();
    }

    /** Builds a {@link Sort} whose sort fields are {@code sortFieldsFn.apply(input)}, without a remap. */
    public Sort sort(Function<Rel, Iterable<? extends Expression.SortField>> sortFieldsFn, Rel input) {
        return sort(sortFieldsFn, Optional.empty(), input);
    }

    /** Builds a {@link Sort} whose sort fields are {@code sortFieldsFn.apply(input)}, with a remap. */
    public Sort sort(
            Function<Rel, Iterable<? extends Expression.SortField>> sortFieldsFn, Rel.Remap remap, Rel input) {
        return sort(sortFieldsFn, Optional.of(remap), input);
    }

    private Sort sort(
            Function<Rel, Iterable<? extends Expression.SortField>> sortFieldsFn,
            Optional<Rel.Remap> remap,
            Rel input) {
        return Sort.builder().input(input).sortFields(sortFieldsFn.apply(input)).remap(remap).build();
    }

    // ---------------------------------------------------------------- Expressions

    /** @return a boolean literal holding {@code value} */
    public Expression.BoolLiteral bool(boolean value) {
        return Expression.BoolLiteral.builder().value(value).build();
    }

    /** @return an i32 literal holding {@code value} */
    public Expression.I32Literal i32(int value) {
        return Expression.I32Literal.builder().value(value).build();
    }

    /** @return a cast of {@code input} to {@code type} with {@code FailureBehavior.UNSPECIFIED} */
    public Expression cast(Expression input, Type type) {
        return ImmutableExpression.Cast.builder()
                .input(input)
                .type(type)
                .failureBehavior(Expression.FailureBehavior.UNSPECIFIED)
                .build();
    }

    /** @return a reference to field {@code index} of {@code input} */
    public FieldReference fieldReference(Rel input, int index) {
        return ImmutableFieldReference.newInputRelReference(index, input);
    }

    /** @return references to each of {@code indexes} of {@code input}, in the given order */
    public List<FieldReference> fieldReferences(Rel input, int... indexes) {
        return Arrays.stream(indexes)
                .mapToObj(index -> fieldReference(input, index))
                .collect(Collectors.toList());
    }

    /** @return a reference to field {@code index} across the given list of inputs */
    public FieldReference fieldReference(List<Rel> inputs, int index) {
        return ImmutableFieldReference.newInputRelReference(index, inputs);
    }

    /** @return references to each of {@code indexes} across the given list of inputs, in order */
    public List<FieldReference> fieldReferences(List<Rel> inputs, int... indexes) {
        return Arrays.stream(indexes)
                .mapToObj(index -> fieldReference(inputs, index))
                .collect(Collectors.toList());
    }

    /** @return an if/then expression over {@code ifClauses} with {@code elseClause} as the fallback */
    public Expression.IfThen ifThen(Iterable<? extends Expression.IfClause> ifClauses, Expression elseClause) {
        return Expression.IfThen.builder().addAllIfClauses(ifClauses).elseClause(elseClause).build();
    }

    /** @return an if clause yielding {@code then} when {@code condition} holds */
    public Expression.IfClause ifClause(Expression condition, Expression then) {
        return Expression.IfClause.builder().condition(condition).then(then).build();
    }

    /** @return a single-or-list expression testing {@code condition} against {@code options} */
    public Expression singleOrList(Expression condition, Expression... options) {
        return ImmutableExpression.SingleOrList.builder().condition(condition).addOptions(options).build();
    }

    /**
     * Builds one ascending, nulls-last sort field per index.
     *
     * @param input relation whose fields are referenced
     * @param indexes field indexes to sort by, in priority order
     * @return the sort fields
     */
    public List<Expression.SortField> sortFields(Rel input, int... indexes) {
        return Arrays.stream(indexes)
                .mapToObj(index -> Expression.SortField.builder()
                        .expr(fieldReference(input, index))
                        .direction(Expression.SortDirection.ASC_NULLS_LAST)
                        .build())
                .collect(Collectors.toList());
    }

    /** @return a switch clause yielding {@code then} when the switch value matches {@code condition} */
    public Expression.SwitchClause switchClause(Expression.Literal condition, Expression then) {
        return Expression.SwitchClause.builder().condition(condition).then(then).build();
    }

    /** @return a switch on {@code match} over {@code clauses}, falling back to {@code defaultClause} */
    public ImmutableExpression.Switch switchExpression(
            Expression match, Iterable<? extends Expression.SwitchClause> clauses, Expression defaultClause) {
        return ImmutableExpression.Switch.builder()
                .match(match)
                .addAllSwitchClauses(clauses)
                .defaultClause(defaultClause)
                .build();
    }

    // ---------------------------------------------------------------- Aggregate functions

    /**
     * Builds a single-phase ({@code INITIAL_TO_RESULT}, {@code ALL}) aggregate invocation.
     *
     * @param uri extension URI of the function declaration
     * @param key function key within that extension (e.g. {@code "count:any"})
     * @param outputType declared output type of the invocation
     * @param args invocation arguments
     */
    public AggregateFunctionInvocation aggregateFn(String uri, String key, Type outputType, Expression... args) {
        return AggregateFunctionInvocation.builder()
                .arguments(Arrays.stream(args).collect(Collectors.toList()))
                .outputType(outputType)
                .declaration(extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(uri, key)))
                .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
                .invocation(Expression.AggregationInvocation.ALL)
                .build();
    }

    /** @return a grouping whose expressions are references to {@code indexes} of {@code input} */
    public Aggregate.Grouping grouping(Rel input, int... indexes) {
        return Aggregate.Grouping.builder().addAllExpressions(fieldReferences(input, indexes)).build();
    }

    /** @return a {@code count:any} invocation over field {@code field}, declared as non-nullable I64 */
    public AggregateFunctionInvocation count(Rel input, int field) {
        return AggregateFunctionInvocation.builder()
                .arguments(fieldReferences(input, field))
                .outputType(R.I64)
                .declaration(extensions.getAggregateFunction(
                        SimpleExtension.FunctionAnchor.of(FUNCTIONS_AGGREGATE_GENERIC, "count:any")))
                .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
                .invocation(Expression.AggregationInvocation.ALL)
                .build();
    }

    /** @return a {@code min} invocation over field {@code field}; output is the field type made nullable */
    public AggregateFunctionInvocation min(Rel input, int field) {
        return singleArgumentArithmeticAggregate(input, field, "min", nullableFieldType(input, field));
    }

    /** @return a {@code max} invocation over field {@code field}; output is the field type made nullable */
    public AggregateFunctionInvocation max(Rel input, int field) {
        return singleArgumentArithmeticAggregate(input, field, "max", nullableFieldType(input, field));
    }

    /** @return an {@code avg} invocation over field {@code field}; output is the field type made nullable */
    public AggregateFunctionInvocation avg(Rel input, int field) {
        return singleArgumentArithmeticAggregate(input, field, "avg", nullableFieldType(input, field));
    }

    /** @return a {@code sum} invocation over field {@code field}; output is the field type made nullable */
    public AggregateFunctionInvocation sum(Rel input, int field) {
        return singleArgumentArithmeticAggregate(input, field, "sum", nullableFieldType(input, field));
    }

    /**
     * @return a {@code sum0} invocation over field {@code field}, always declared as non-nullable I64
     */
    // NOTE(review): the output type is I64 regardless of the argument type. Confirm this matches
    // the functions_arithmetic.yaml return type for floating-point arguments.
    public AggregateFunctionInvocation sum0(Rel input, int field) {
        return singleArgumentArithmeticAggregate(input, field, "sum0", R.I64);
    }

    /** Returns the type of field {@code field} of {@code input}, converted to its nullable form. */
    private static Type nullableFieldType(Rel input, int field) {
        return TypeCreator.asNullable(input.getRecordType().fields().get(field));
    }

    /**
     * Builds a single-phase aggregate from {@code functions_arithmetic.yaml}. The function key is
     * {@code "<functionName>:<type string of the argument field>"}.
     */
    private AggregateFunctionInvocation singleArgumentArithmeticAggregate(
            Rel input, int field, String functionName, Type outputType) {
        List<FieldReference> arguments = fieldReferences(input, field);
        String typeString = input.getRecordType().fields().get(field).accept(ToTypeString.INSTANCE);
        String key = String.format("%s:%s", functionName, typeString);
        return AggregateFunctionInvocation.builder()
                .arguments(arguments)
                .outputType(outputType)
                .declaration(extensions.getAggregateFunction(
                        SimpleExtension.FunctionAnchor.of(FUNCTIONS_ARITHMETIC, key)))
                .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
                .invocation(Expression.AggregationInvocation.ALL)
                .build();
    }

    // ---------------------------------------------------------------- Scalar functions

    /**
     * @return an {@code equal:any_any} invocation, declared as non-nullable BOOLEAN
     */
    // NOTE(review): unlike or(...), the output nullability does not follow the arguments' nullability.
    // Confirm whether nullable arguments should yield N.BOOLEAN here.
    public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
        return scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
    }

    /**
     * @return an {@code or:bool} invocation over {@code args}. It is declared nullable BOOLEAN when
     *     any argument's type is nullable, and non-nullable BOOLEAN otherwise.
     */
    public Expression.ScalarFunctionInvocation or(Expression... args) {
        boolean anyNullable = Arrays.stream(args).anyMatch(arg -> arg.getType().nullable());
        return scalarFn(FUNCTIONS_BOOLEAN, "or:bool", anyNullable ? N.BOOLEAN : R.BOOLEAN, args);
    }

    /**
     * Builds a scalar function invocation.
     *
     * @param uri extension URI of the function declaration
     * @param key function key within that extension (e.g. {@code "equal:any_any"})
     * @param outputType declared output type of the invocation
     * @param args invocation arguments
     */
    public Expression.ScalarFunctionInvocation scalarFn(String uri, String key, Type outputType, Expression... args) {
        return Expression.ScalarFunctionInvocation.builder()
                .declaration(extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(uri, key)))
                .outputType(outputType)
                .arguments(Arrays.stream(args).collect(Collectors.toList()))
                .build();
    }

    // ---------------------------------------------------------------- Types and plans

    /** @return a non-nullable user-defined type identified by {@code uri} and {@code name} */
    public Type.UserDefined userDefinedType(String uri, String name) {
        return ImmutableType.UserDefined.builder().uri(uri).name(name).nullable(false).build();
    }

    /** @return a plan root wrapping {@code input} */
    public Plan.Root root(Rel input) {
        return ImmutableRoot.builder().input(input).build();
    }

    /** @return a plan containing the single root {@code root} */
    public Plan plan(Plan.Root root) {
        return ImmutablePlan.builder().addRoots(root).build();
    }

    /** @return a remap selecting {@code fields}, in the given order */
    public Rel.Remap remap(Integer... fields) {
        return Rel.Remap.of(Arrays.asList(fields));
    }
}
