package com.linkedin.feathr.offline.evaluator.aggregation;

import com.linkedin.feathr.common.FeatureTypeConfig;
import com.linkedin.feathr.common.FeatureValue;
import com.linkedin.feathr.compute.Aggregation;
import com.linkedin.feathr.compute.AnyNode;
import com.linkedin.feathr.exception.ErrorLabel;
import com.linkedin.feathr.exception.FeathrConfigException;
import com.linkedin.feathr.offline.anchored.WindowTimeUnit$;
import com.linkedin.feathr.offline.client.NOT_VISITED$;
import com.linkedin.feathr.offline.client.VISITED$;
import com.linkedin.feathr.offline.client.VisitedState;
import com.linkedin.feathr.offline.config.JoinConfigSettings;
import com.linkedin.feathr.offline.config.JoinTimeSetting;
import com.linkedin.feathr.offline.evaluator.NodeEvaluator;
import com.linkedin.feathr.offline.graph.DataframeAndColumnMetadata;
import com.linkedin.feathr.offline.graph.DataframeAndColumnMetadata$;
import com.linkedin.feathr.offline.graph.FCMGraphTraverser;
import com.linkedin.feathr.offline.graph.NodeGrouper$;
import com.linkedin.feathr.offline.graph.NodeUtils$;
import com.linkedin.feathr.offline.job.FeatureTransformation$;
import com.linkedin.feathr.offline.source.accessor.DataPathHandler;
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils$;
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter$;
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat$;
import com.linkedin.feathr.swj.FactData;
import com.linkedin.feathr.swj.GroupBySpec;
import com.linkedin.feathr.swj.LabelData;
import com.linkedin.feathr.swj.LateralViewParams;
import com.linkedin.feathr.swj.SlidingWindowFeature;
import com.linkedin.feathr.swj.SlidingWindowJoin$;
import com.linkedin.feathr.swj.WindowSpec;
import com.linkedin.feathr.swj.aggregate.AggregationSpec;
import com.linkedin.feathr.swj.aggregate.AggregationType$;
import com.linkedin.feathr.swj.aggregate.AvgAggregate;
import com.linkedin.feathr.swj.aggregate.AvgAggregate$;
import com.linkedin.feathr.swj.aggregate.AvgPoolingAggregate;
import com.linkedin.feathr.swj.aggregate.AvgPoolingAggregate$;
import com.linkedin.feathr.swj.aggregate.CountAggregate;
import com.linkedin.feathr.swj.aggregate.LatestAggregate;
import com.linkedin.feathr.swj.aggregate.MaxAggregate;
import com.linkedin.feathr.swj.aggregate.MaxPoolingAggregate;
import com.linkedin.feathr.swj.aggregate.MinAggregate;
import com.linkedin.feathr.swj.aggregate.MinPoolingAggregate;
import com.linkedin.feathr.swj.aggregate.SumAggregate;
import java.time.Duration;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.Enumeration;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.$colon;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.Map;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: AggregationNodeEvaluator.scala */
/* loaded from: input_file:com/linkedin/feathr/offline/evaluator/aggregation/AggregationNodeEvaluator$.class */
/**
 * Java view of the Scala singleton object {@code AggregationNodeEvaluator}: the {@link NodeEvaluator}
 * for sliding-window aggregation (SWA) nodes of the compute graph.
 *
 * <p>Aggregation nodes are grouped with {@code NodeGrouper.groupSWANodes}; each group is evaluated with
 * one {@code SlidingWindowJoin.join} of the observation dataframe ("label data") against the
 * aggregation inputs ("fact data"). Each resulting feature column is then converted to FDS tensor
 * format, default-substituted, and registered in the traverser's node metadata map.
 */
public final class AggregationNodeEvaluator$ implements NodeEvaluator {
    /** Singleton instance, assigned by the private constructor during static initialization. */
    public static AggregationNodeEvaluator$ MODULE$;

    static {
        new AggregationNodeEvaluator$();
    }

    /**
     * Builds the observation ("label") side of the sliding-window join.
     *
     * @param aggregation the aggregation whose concrete key node ids select the join key expressions
     * @param option join config settings; must carry a join time setting
     * @param dataset the observation dataframe
     * @param map node id to dataframe/column metadata, used to resolve key expressions
     * @return label data whose join keys are the resolved key expressions cast to string, and whose
     *     timestamp expression is {@code unix_timestamp()} when {@code useLatestFeatureData} is set,
     *     otherwise one built from the configured timestamp column name and format
     * @throws FeathrConfigException if the join config settings or their join time setting are absent
     */
    private LabelData getLabelData(Aggregation aggregation, Option<JoinConfigSettings> option, Dataset<Row> dataset, Map<Object, DataframeAndColumnMetadata> map) {
        // Keys are resolved first so any key-lookup failure surfaces before the config check below.
        Buffer obsKeys = (Buffer) ((Buffer) ((TraversableLike) JavaConverters$.MODULE$.asScalaBufferConverter(aggregation.getConcreteKey().getKey()).asScala()).flatMap(num -> {
            return ((DataframeAndColumnMetadata) map.apply(BoxesRunTime.boxToInteger(Predef$.MODULE$.Integer2int((Integer) num)))).keyExpression();
        }, Buffer$.MODULE$.canBuildFrom())).map(keyExpr -> "CAST (" + keyExpr + " AS string)", Buffer$.MODULE$.canBuildFrom());

        // Without a join time setting there is neither a timestamp column nor useLatestFeatureData to
        // use; fail with a descriptive config error instead of an opaque None.get.
        if (option.isEmpty() || ((JoinConfigSettings) option.get()).joinTimeSetting().isEmpty()) {
            throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, "Sliding window aggregation requires joinTimeSettings in the join config (either a timestampColumn or useLatestFeatureData)");
        }
        JoinTimeSetting joinTimeSetting = (JoinTimeSetting) ((JoinConfigSettings) option.get()).joinTimeSetting().get();
        String timestampExpr = joinTimeSetting.useLatestFeatureData()
                ? "unix_timestamp()"
                : SlidingWindowFeatureUtils$.MODULE$.constructTimeStampExpr(joinTimeSetting.timestampColumn().name(), joinTimeSetting.timestampColumn().format(), SlidingWindowFeatureUtils$.MODULE$.constructTimeStampExpr$default$3());
        return new LabelData(dataset, obsKeys, timestampExpr);
    }

    /**
     * Reads the lateral view definition from the aggregation function parameters.
     *
     * <p>Only the index-0 entries {@code lateral_view_expression_0} and
     * {@code lateral_view_table_alias_0} are consulted.
     *
     * @return lateral view params (with no filter) when both parameters are present, otherwise empty
     */
    private Option<LateralViewParams> getLateralViewParams(Aggregation aggregation) {
        Option<String> lateralViewExpr = Option.apply((String) aggregation.getFunction().getParameters().get("lateral_view_expression_0"));
        Option<String> lateralViewTableAlias = Option.apply((String) aggregation.getFunction().getParameters().get("lateral_view_table_alias_0"));
        if (lateralViewExpr.isDefined() && lateralViewTableAlias.isDefined()) {
            return new Some<>(new LateralViewParams(lateralViewExpr.get(), lateralViewTableAlias.get(), Option.empty()));
        }
        return Option.empty();
    }

    /**
     * Maps an {@code AggregationType} enumeration value to the aggregation spec that aggregates the
     * SQL expression {@code str}.
     *
     * <p>COUNT counts the rows where {@code str} is non-null. AVG and AVG_POOLING use their
     * constructors' default second argument.
     *
     * @throws MatchError if {@code value} is not one of the supported aggregation types
     */
    private AggregationSpec getAggSpec(Enumeration.Value value, String str) {
        if (sameValue(AggregationType$.MODULE$.SUM(), value)) {
            return new SumAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.COUNT(), value)) {
            return new CountAggregate("CASE WHEN " + str + " IS NOT NULL THEN 1 ELSE 0 END");
        }
        if (sameValue(AggregationType$.MODULE$.AVG(), value)) {
            return new AvgAggregate(str, AvgAggregate$.MODULE$.$lessinit$greater$default$2());
        }
        if (sameValue(AggregationType$.MODULE$.MAX(), value)) {
            return new MaxAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.MIN(), value)) {
            return new MinAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.LATEST(), value)) {
            return new LatestAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.MAX_POOLING(), value)) {
            return new MaxPoolingAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.MIN_POOLING(), value)) {
            return new MinPoolingAggregate(str);
        }
        if (sameValue(AggregationType$.MODULE$.AVG_POOLING(), value)) {
            return new AvgPoolingAggregate(str, AvgPoolingAggregate$.MODULE$.$lessinit$greater$default$2());
        }
        throw new MatchError(value);
    }

    /** Null-safe equality of enumeration values (equivalent to Scala's {@code ==}). */
    private static boolean sameValue(Enumeration.Value expected, Enumeration.Value actual) {
        return expected != null ? expected.equals(actual) : actual == null;
    }

    /**
     * Returns the simulated time delay to apply to feature {@code str}.
     *
     * <p>A per-feature override in {@code map} (parsed with {@code WindowTimeUnit.parseWindowTime})
     * wins. Otherwise the join time setting's {@code simulateTimeDelay} is used, else
     * {@link Duration#ZERO}.
     *
     * @throws FeathrConfigException if an override exists but no global {@code simulateTimeDelay} is set
     */
    private Duration getSimTimeDelay(String str, Option<JoinConfigSettings> option, scala.collection.immutable.Map<String, String> map) {
        boolean hasGlobalSimulateTimeDelay = option.isDefined()
                && ((JoinConfigSettings) option.get()).joinTimeSetting().isDefined()
                && ((JoinTimeSetting) ((JoinConfigSettings) option.get()).joinTimeSetting().get()).simulateTimeDelay().isDefined();
        if (map.contains(str)) {
            if (!hasGlobalSimulateTimeDelay) {
                throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR, "overrideTimeDelay cannot be defined without setting a simulateTimeDelay in the joinTimeSettings");
            }
            return WindowTimeUnit$.MODULE$.parseWindowTime((String) map.apply(str));
        }
        return hasGlobalSimulateTimeDelay
                ? (Duration) ((JoinTimeSetting) ((JoinConfigSettings) option.get()).joinTimeSetting().get()).simulateTimeDelay().get()
                : Duration.ZERO;
    }

    /**
     * Builds the fact-data side of the sliding-window join for the SWA group headed by {@code aggregation}.
     *
     * <p>The group's aggregation nodes (ids from {@code map2}) are partitioned by the (data source,
     * key expressions, lateral view params) of their inputs. Each partition becomes one
     * {@link FactData}, ordered largest partition first.
     *
     * <p>Side effect: records each feature's column format (RAW or FDS_TENSOR) in {@code hashMap}.
     */
    private List<FactData> getFactDataSet(scala.collection.immutable.Map<Integer, AnyNode> map, scala.collection.immutable.Map<Integer, Seq<Integer>> map2, Aggregation aggregation, Map<Object, DataframeAndColumnMetadata> map3, HashMap<String, Enumeration.Value> hashMap, Option<JoinConfigSettings> option, scala.collection.immutable.Map<String, String> map4, scala.collection.immutable.Map<Integer, String> map5) {
        return (List) ((List) ((Seq) ((Seq) map2.apply(aggregation.getId())).map(num -> {
            return (AnyNode) map.apply((Integer) num);
        }, Seq$.MODULE$.canBuildFrom())).groupBy(node -> {
            AnyNode anyNode = (AnyNode) node;
            DataframeAndColumnMetadata inputMetadata = (DataframeAndColumnMetadata) map3.apply(BoxesRunTime.boxToInteger(Predef$.MODULE$.Integer2int(anyNode.getAggregation().getInput().getId())));
            return new Tuple3(inputMetadata.dataSource(), inputMetadata.keyExpression(), MODULE$.getLateralViewParams(anyNode.getAggregation()));
        }).values().toList().sortBy(group -> {
            return BoxesRunTime.boxToInteger(((Seq) group).size());
        }, Ordering$Int$.MODULE$)).reverse().map(group -> {
            return MODULE$.buildFactData((Seq<AnyNode>) group, map3, hashMap, option, map4, map5);
        }, List$.MODULE$.canBuildFrom());
    }

    /**
     * Builds one {@link FactData} for a group of aggregation nodes sharing the same input source,
     * key expressions and lateral view params.
     *
     * <p>The first node whose input id is present in {@code map3} supplies the input dataframe, key
     * expressions (cast to string), timestamp column and lateral view params for the whole group.
     * {@code head()} throws if no node in the group has registered input metadata.
     */
    private FactData buildFactData(Seq<AnyNode> group, Map<Object, DataframeAndColumnMetadata> map3, HashMap<String, Enumeration.Value> hashMap, Option<JoinConfigSettings> option, scala.collection.immutable.Map<String, String> map4, scala.collection.immutable.Map<Integer, String> map5) {
        Aggregation representative = ((AnyNode) ((IterableLike) group.filter(anyNode -> {
            return BoxesRunTime.boxToBoolean($anonfun$getFactDataSet$5(map3, anyNode));
        })).head()).getAggregation();
        DataframeAndColumnMetadata inputMetadata = (DataframeAndColumnMetadata) map3.apply(BoxesRunTime.boxToInteger(Predef$.MODULE$.Integer2int(representative.getInput().getId())));
        Dataset<Row> df = inputMetadata.df();
        Seq<String> keyExpression = inputMetadata.keyExpression();
        String timestampColumn = (String) inputMetadata.timestampColumn().get();
        Seq castKeyExpressions = (Seq) keyExpression.map(keyExpr -> "CAST (" + keyExpr + " AS string)", Seq$.MODULE$.canBuildFrom());
        Option<LateralViewParams> lateralViewParams = getLateralViewParams(representative);
        List features = ((Seq) group.map(anyNode -> buildSlidingWindowFeature(anyNode, lateralViewParams, hashMap, option, map4, map5), Seq$.MODULE$.canBuildFrom())).toList();
        return new FactData(df, castKeyExpressions, timestampColumn, features);
    }

    /**
     * Translates one aggregation node into a {@link SlidingWindowFeature} and records the feature's
     * column format in {@code hashMap}.
     *
     * <p>A target column containing the multi-dim FDS tensor UDF name is parsed into its inner
     * expression and marked FDS_TENSOR; anything else is used verbatim and marked RAW.
     */
    private SlidingWindowFeature buildSlidingWindowFeature(AnyNode anyNode, Option<LateralViewParams> lateralViewParams, HashMap<String, Enumeration.Value> hashMap, Option<JoinConfigSettings> option, scala.collection.immutable.Map<String, String> map4, scala.collection.immutable.Map<Integer, String> map5) {
        Aggregation aggregation = anyNode.getAggregation();
        String featureName = (String) map5.apply(aggregation.getId());
        Enumeration.Value aggregationType = AggregationType$.MODULE$.withName((String) aggregation.getFunction().getParameters().get("aggregation_type"));
        String targetColumn = (String) aggregation.getFunction().getParameters().get("target_column");
        String aggregationExpr;
        Enumeration.Value columnFormat;
        if (targetColumn.contains(FeatureTransformation$.MODULE$.USER_FACING_MULTI_DIM_FDS_TENSOR_UDF_NAME())) {
            aggregationExpr = (String) FeatureTransformation$.MODULE$.parseMultiDimTensorExpr(targetColumn);
            columnFormat = FeatureColumnFormat$.MODULE$.FDS_TENSOR();
        } else {
            aggregationExpr = targetColumn;
            columnFormat = FeatureColumnFormat$.MODULE$.RAW();
        }
        AggregationSpec aggSpec = getAggSpec(aggregationType, aggregationExpr);
        Duration windowSize = Duration.parse((CharSequence) aggregation.getFunction().getParameters().get("window_size"));
        Duration simTimeDelay = getSimTimeDelay(featureName, option, map4);
        Option<String> filterExpr = Option.apply((String) aggregation.getFunction().getParameters().get("filter_expression"));
        String groupByExpr = (String) aggregation.getFunction().getParameters().get("group_by_expression");
        String maxNumberGroupsParam = (String) aggregation.getFunction().getParameters().get("max_number_groups");
        // Parsed eagerly: a malformed value fails even when there is no group-by expression.
        int maxNumberGroups = maxNumberGroupsParam != null ? Integer.parseInt(maxNumberGroupsParam) : 0;
        Option<GroupBySpec> groupBySpec;
        if (groupByExpr != null) {
            groupBySpec = new Some<>(new GroupBySpec(groupByExpr, maxNumberGroups));
        } else {
            groupBySpec = Option.empty();
        }
        hashMap.update(featureName, columnFormat);
        return new SlidingWindowFeature(featureName, aggSpec, new WindowSpec(windowSize, simTimeDelay), filterExpr, groupBySpec, lateralViewParams);
    }

    /**
     * Evaluates all SWA nodes in {@code seq}, one sliding-window join per SWA group (largest group
     * first), and returns the observation dataframe with the feature columns added.
     *
     * <p>A group is skipped when its head node is already marked visited. Nodes are tracked in a
     * local array indexed by aggregation id.
     */
    @Override // com.linkedin.feathr.offline.evaluator.NodeEvaluator
    public Dataset<Row> batchEvaluate(Seq<AnyNode> seq, FCMGraphTraverser fCMGraphTraverser, Dataset<Row> dataset, List<DataPathHandler> list) {
        HashMap<Integer, Seq<Integer>> groupSWANodes = NodeGrouper$.MODULE$.groupSWANodes(seq);
        // groupSWANodes is not modified below, so its immutable snapshot is built once, not per group.
        scala.collection.immutable.Map swaGroupsById = groupSWANodes.toMap(Predef$.MODULE$.$conforms());
        scala.collection.immutable.Map aggregationNodesById = ((TraversableOnce) ((TraversableLike) fCMGraphTraverser.nodes().filter(anyNode -> {
            return BoxesRunTime.boxToBoolean(anyNode.isAggregation());
        })).map(anyNode2 -> {
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(anyNode2.getAggregation().getId()), anyNode2);
        }, Buffer$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
        HashMap<String, Enumeration.Value> featureColumnFormatsMap = fCMGraphTraverser.featureColumnFormatsMap();
        scala.collection.immutable.Map<String, FeatureValue> defaultConverter = NodeUtils$.MODULE$.getDefaultConverter(seq);
        scala.collection.immutable.Map<String, FeatureTypeConfig> featureTypeConfigsMap = NodeUtils$.MODULE$.getFeatureTypeConfigsMap(seq);
        // Mutable holder so the lambdas below can thread the evolving dataframe through each join.
        ObjectRef create = ObjectRef.create(dataset);
        VisitedState[] visitedStateArr = (VisitedState[]) Array$.MODULE$.fill(fCMGraphTraverser.nodes().length(), () -> {
            return NOT_VISITED$.MODULE$;
        }, ClassTag$.MODULE$.apply(VisitedState.class));
        ((List) groupSWANodes.values().toList().sortBy(group -> {
            return BoxesRunTime.boxToInteger(((Seq) group).size());
        }, Ordering$Int$.MODULE$)).reverse().foreach(group -> {
            AnyNode representative = (AnyNode) aggregationNodesById.apply(((Seq) group).head());
            VisitedState visitedState = visitedStateArr[Predef$.MODULE$.Integer2int(representative.getAggregation().getId())];
            VISITED$ visited$ = VISITED$.MODULE$;
            if (visitedState != null ? visitedState.equals(visited$) : visited$ == null) {
                return BoxedUnit.UNIT;
            }
            create.elem = SlidingWindowJoin$.MODULE$.join(
                    MODULE$.getLabelData(representative.getAggregation(), fCMGraphTraverser.timeConfigSettings().timeConfigSettings(), (Dataset) create.elem, fCMGraphTraverser.nodeIdToDataframeAndColumnMetadataMap()),
                    MODULE$.getFactDataSet(aggregationNodesById, swaGroupsById, representative.getAggregation(), fCMGraphTraverser.nodeIdToDataframeAndColumnMetadataMap(), featureColumnFormatsMap, fCMGraphTraverser.timeConfigSettings().timeConfigSettings(), fCMGraphTraverser.timeConfigSettings().featuresToTimeDelayMap(), fCMGraphTraverser.nodeIdToFeatureName()),
                    SlidingWindowJoin$.MODULE$.join$default$3());
            groupSWANodes.apply(representative.getAggregation().getId()).foreach(num -> {
                $anonfun$batchEvaluate$6(fCMGraphTraverser, create, featureColumnFormatsMap, featureTypeConfigsMap, defaultConverter, visitedStateArr, num);
                return BoxedUnit.UNIT;
            });
            return BoxedUnit.UNIT;
        });
        return (Dataset) create.elem;
    }

    /** Evaluates a single node by delegating to {@link #batchEvaluate} with a one-element list. */
    @Override // com.linkedin.feathr.offline.evaluator.NodeEvaluator
    public Dataset<Row> evaluate(AnyNode anyNode, FCMGraphTraverser fCMGraphTraverser, Dataset<Row> dataset, List<DataPathHandler> list) {
        return batchEvaluate((Seq) new $colon.colon(anyNode, Nil$.MODULE$), fCMGraphTraverser, dataset, list);
    }

    /** Returns whether {@code anyNode}'s aggregation input id has an entry in {@code map}. */
    public static final /* synthetic */ boolean $anonfun$getFactDataSet$5(Map map, AnyNode anyNode) {
        return map.contains(BoxesRunTime.boxToInteger(Predef$.MODULE$.Integer2int(anyNode.getAggregation().getInput().getId())));
    }

    /**
     * Post-processes the feature column produced for node {@code num} and marks that node visited.
     *
     * <p>Steps:
     * <ol>
     *   <li>Convert the column to FDS.</li>
     *   <li>Record its format as FDS_TENSOR.</li>
     *   <li>Substitute defaults.</li>
     *   <li>Rename it, re-add it under its name and drop the renamed copy. This presumably moves the
     *       column to the end of the schema.</li>
     *   <li>Register the resulting dataframe as the node's metadata.</li>
     * </ol>
     */
    public static final /* synthetic */ void $anonfun$batchEvaluate$6(FCMGraphTraverser fCMGraphTraverser, ObjectRef objectRef, HashMap hashMap, scala.collection.immutable.Map map, scala.collection.immutable.Map map2, VisitedState[] visitedStateArr, Integer num) {
        String str = (String) fCMGraphTraverser.nodeIdToFeatureName().apply(num);
        String dummyColumn = str + "__dummy__";
        objectRef.elem = SlidingWindowFeatureUtils$.MODULE$.convertSWADFToFDS((Dataset) objectRef.elem, (Set) Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new String[]{str})), hashMap.toMap(Predef$.MODULE$.$conforms()), map).df();
        hashMap.update(str, FeatureColumnFormat$.MODULE$.FDS_TENSOR());
        objectRef.elem = DataFrameDefaultValueSubstituter$.MODULE$.substituteDefaults2((Dataset<Row>) objectRef.elem, (Seq<String>) new $colon.colon(str, Nil$.MODULE$), (scala.collection.immutable.Map<String, FeatureValue>) map2, (scala.collection.immutable.Map<String, FeatureTypeConfig>) map, fCMGraphTraverser.ss(), DataFrameDefaultValueSubstituter$.MODULE$.substituteDefaults$default$6());
        objectRef.elem = ((Dataset) objectRef.elem).withColumnRenamed(str, dummyColumn);
        objectRef.elem = ((Dataset) objectRef.elem).withColumn(str, functions$.MODULE$.col(dummyColumn));
        objectRef.elem = ((Dataset) objectRef.elem).drop(dummyColumn);
        fCMGraphTraverser.nodeIdToDataframeAndColumnMetadataMap().update(BoxesRunTime.boxToInteger(Predef$.MODULE$.Integer2int(num)), new DataframeAndColumnMetadata((Dataset) objectRef.elem, Nil$.MODULE$, new Some(str), DataframeAndColumnMetadata$.MODULE$.apply$default$4(), DataframeAndColumnMetadata$.MODULE$.apply$default$5()));
        visitedStateArr[Predef$.MODULE$.Integer2int(num)] = VISITED$.MODULE$;
    }

    private AggregationNodeEvaluator$() {
        MODULE$ = this;
    }
}
