package com.linkedin.coral.pig.rel2pig.rel;

import com.linkedin.coral.pig.rel2pig.exceptions.UnsupportedRexCallException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;

/* loaded from: input_file:com/linkedin/coral/pig/rel2pig/rel/PigLogicalAggregate.class */
/**
 * Translates a Calcite {@link LogicalAggregate} into Pig Latin.
 *
 * <p>The generated script consists of two statements, both assigned to the output alias:
 * a {@code GROUP ... BY (...)} statement (or {@code GROUP ... ALL} when there are no group keys),
 * followed by a {@code FOREACH ... GENERATE} over the grouped relation that projects the group
 * key(s) and evaluates each aggregate call.
 */
public final class PigLogicalAggregate {
    private static final String GROUP_BY_TEMPLATE = "%s = GROUP %s BY (%s);";
    private static final String GROUP_ALL_TEMPLATE = "%s = GROUP %s ALL;";
    private static final String AGGREGATE_TEMPLATE = "%s = FOREACH %s GENERATE %s;";
    private static final String FIELD_TEMPLATE = "%s AS %s";
    private static final String AGGREGATE_CALL_TEMPLATE = "%s(%s) AS %s";

    private PigLogicalAggregate() {
    }

    /**
     * Returns the Pig Latin statements implementing {@code logicalAggregate}.
     *
     * @param logicalAggregate the aggregate to translate; must have exactly one grouping set
     * @param str alias assigned by the generated statements (the output relation)
     * @param str2 alias of the relation that feeds the aggregate (the input relation)
     * @return the GROUP statement and the FOREACH statement, separated by a newline
     * @throws UnsupportedRexCallException if the aggregate does not have exactly one grouping set
     */
    public static String getScript(LogicalAggregate logicalAggregate, String str, String str2) {
        if (logicalAggregate.getGroupSets().size() != 1) {
            throw new UnsupportedRexCallException("Only grouping sets of size 1 is supported");
        }
        String groupStatement = getGroupByStatement(logicalAggregate, str, str2);
        // The FOREACH must iterate over the relation produced by the GROUP statement (alias `str`).
        // Inside each group the bag of input rows is named after the input alias (`str2`), which is
        // what the aggregate calls reference.
        String forEachStatement = getForEachStatement(logicalAggregate, str, str, str2);
        return String.join("\n", groupStatement, forEachStatement);
    }

    /**
     * Builds {@code out = GROUP in BY (k1, k2, ...);}, or {@code out = GROUP in ALL;} when the
     * aggregate has no group keys. Group keys are named using the input relation's field names.
     */
    private static String getGroupByStatement(LogicalAggregate logicalAggregate, String outputRelation,
            String inputRelation) {
        List<Integer> groupKeys = logicalAggregate.getGroupSet().toList();
        if (groupKeys.isEmpty()) {
            return String.format(GROUP_ALL_TEMPLATE, outputRelation, inputRelation);
        }
        List<String> inputFieldNames = PigRelUtils.getOutputFieldNames(logicalAggregate.getInput());
        String keys = groupKeys.stream()
                .map(inputFieldNames::get)
                .collect(Collectors.joining(", "));
        return String.format(GROUP_BY_TEMPLATE, outputRelation, inputRelation, keys);
    }

    /**
     * Builds {@code out = FOREACH grouped GENERATE <group fields>, <aggregate calls>;}.
     *
     * @param outputRelation alias assigned by the statement
     * @param groupedRelation alias of the relation produced by the GROUP statement
     * @param inputRelation alias of the pre-grouping input; names the per-group bag of rows
     */
    private static String getForEachStatement(LogicalAggregate logicalAggregate, String outputRelation,
            String groupedRelation, String inputRelation) {
        List<String> outputFieldNames = PigRelUtils.getOutputFieldNames(logicalAggregate);
        List<String> inputFieldNames = PigRelUtils.getOutputFieldNames(logicalAggregate.getInput());
        List<Integer> groupKeys = logicalAggregate.getGroupSet().toList();

        List<String> projections = new ArrayList<>(2);
        if (!groupKeys.isEmpty()) {
            projections.add(getGroupSetFields(groupKeys, outputFieldNames, inputFieldNames));
        }
        String aggregateCalls =
                getAggregateFunctionCalls(logicalAggregate, outputFieldNames, inputFieldNames, inputRelation);
        // Skip an empty call list (e.g. a pure GROUP BY) so no dangling ", " is emitted.
        if (!aggregateCalls.isEmpty()) {
            projections.add(aggregateCalls);
        }
        return String.format(AGGREGATE_TEMPLATE, outputRelation, groupedRelation, String.join(", ", projections));
    }

    /**
     * Projects the Pig {@code group} field back to the aggregate's output column names.
     *
     * <p>The aggregate's output row starts with its group keys in group-set order, so the output
     * name of the j-th group key is {@code outputFieldNames.get(j)}, not the key's input index.
     * With a single key Pig's {@code group} is the key value itself; with several keys it is a
     * tuple whose members keep the input field names.
     */
    private static String getGroupSetFields(List<Integer> groupKeys, List<String> outputFieldNames,
            List<String> inputFieldNames) {
        if (groupKeys.size() == 1) {
            return String.format(FIELD_TEMPLATE, "group", outputFieldNames.get(0));
        }
        List<String> fields = new ArrayList<>(groupKeys.size());
        for (int position = 0; position < groupKeys.size(); position++) {
            String inputName = inputFieldNames.get(groupKeys.get(position));
            fields.add(String.format(FIELD_TEMPLATE, String.format("group.%s", inputName),
                    outputFieldNames.get(position)));
        }
        return String.join(", ", fields);
    }

    /**
     * Renders each aggregate call as {@code FUNC(bag.arg1, ...) AS outputName}.
     *
     * <p>Calls without arguments (such as {@code COUNT(*)}) are applied to the whole bag
     * {@code inputRelation}. Aggregate outputs follow the group keys in the output row, hence the
     * {@code groupCount + i} offset.
     */
    private static String getAggregateFunctionCalls(LogicalAggregate logicalAggregate, List<String> outputFieldNames,
            List<String> inputFieldNames, String inputRelation) {
        int groupCount = logicalAggregate.getGroupCount();
        List<AggregateCall> aggCalls = logicalAggregate.getAggCallList();
        List<String> calls = new ArrayList<>(aggCalls.size());
        for (int i = 0; i < aggCalls.size(); i++) {
            AggregateCall aggCall = aggCalls.get(i);
            String args = aggCall.getArgList().stream()
                    .map(index -> String.format("%s.%s", inputRelation, inputFieldNames.get(index)))
                    .collect(Collectors.joining(", "));
            if (args.isEmpty()) {
                args = inputRelation;
            }
            // Locale.ROOT keeps function names stable regardless of the JVM default locale.
            String function = aggCall.getAggregation().getName().toUpperCase(Locale.ROOT);
            calls.add(String.format(AGGREGATE_CALL_TEMPLATE, function, args, outputFieldNames.get(groupCount + i)));
        }
        return String.join(", ", calls);
    }
}
