package com.facebook.presto.tests.statistics;

import com.facebook.presto.cost.PlanNodeCost;
import com.facebook.presto.execution.StageInfo;
import com.facebook.presto.spi.statistics.Estimate;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.planPrinter.PlanNodeStats;
import com.facebook.presto.sql.planner.planPrinter.PlanNodeStatsSummarizer;
import com.facebook.presto.util.MoreMaps;
import com.google.common.collect.Maps;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Compares the cost estimates attached to a query {@link Plan} with the statistics actually
 * observed while executing that plan, producing one {@link MetricComparison} per
 * (metric, estimated plan node) pair.
 */
public class MetricComparator {
    // Every value of the Metric enum is compared.
    private final List<Metric> metrics = Arrays.asList(Metric.values());
    // Tolerance passed to every MetricComparison created by this comparator.
    private final double tolerance = 0.1d;

    /**
     * Builds metric comparisons for every plan node that carries a cost estimate.
     *
     * @param plan the planned query whose per-node cost estimates are checked
     * @param stageInfo the stage info of the executed query; it and all stages reachable via
     *     {@link StageInfo#getAllStages} are scanned for actual statistics
     * @return one comparison per metric and estimated plan node, grouped by metric in
     *     {@link Metric#values()} order; the actual value is empty when no execution
     *     statistics were recorded for a node
     */
    public List<MetricComparison> getMetricComparisons(Plan plan, StageInfo stageInfo) {
        // Neither map depends on the metric, so build them once rather than once per metric.
        Map<PlanNodeId, PlanNodeCost> estimatedCosts = plan.getPlanNodeCosts();
        Map<PlanNodeId, PlanNodeCost> actualCosts = extractActualCosts(stageInfo);
        return metrics.stream()
                .flatMap(metric -> estimatedCosts.entrySet().stream()
                        .map(entry -> {
                            PlanNodeId planNodeId = entry.getKey();
                            PlanNode planNode = planNodeForId(plan, planNodeId);
                            return createMetricComparison(
                                    metric,
                                    planNode,
                                    entry.getValue(),
                                    Optional.ofNullable(actualCosts.get(planNodeId)));
                        }))
                .collect(Collectors.toList());
    }

    /**
     * Finds the single node in {@code plan}'s tree whose id equals {@code planNodeId}.
     * {@code findOnlyElement} is expected to fail if there is not exactly one match.
     */
    private PlanNode planNodeForId(Plan plan, PlanNodeId planNodeId) {
        return PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(planNode -> planNode.getId().equals(planNodeId))
                .findOnlyElement();
    }

    /**
     * Aggregates the per-node execution statistics of every stage of the query and converts
     * them into {@link PlanNodeCost}s holding the observed output row count and output size.
     * Note: {@code Maps.transformValues} returns a lazy view, so the conversion runs on lookup.
     */
    private Map<PlanNodeId, PlanNodeCost> extractActualCosts(StageInfo stageInfo) {
        Stream<Map<PlanNodeId, PlanNodeStats>> perStageStats = StageInfo.getAllStages(Optional.of(stageInfo))
                .stream()
                .map(PlanNodeStatsSummarizer::aggregatePlanNodeStats);
        return Maps.transformValues(mergeStats(perStageStats), this::toPlanNodeCost);
    }

    /**
     * Merges the per-stage statistics maps into one map.
     *
     * @throws IllegalArgumentException if the same plan node id appears in more than one stage
     */
    private Map<PlanNodeId, PlanNodeStats> mergeStats(Stream<Map<PlanNodeId, PlanNodeStats>> stream) {
        return MoreMaps.mergeMaps(stream, (left, right) -> {
            throw new IllegalArgumentException("PlanNodeIds must be unique");
        });
    }

    /** Converts observed node statistics into a cost holding output row count and output bytes. */
    private PlanNodeCost toPlanNodeCost(PlanNodeStats planNodeStats) {
        return PlanNodeCost.builder()
                .setOutputRowCount(new Estimate(planNodeStats.getPlanNodeOutputPositions()))
                .setOutputSizeInBytes(new Estimate(planNodeStats.getPlanNodeOutputDataSize().toBytes()))
                .build();
    }

    /**
     * Creates a comparison of {@code metric} for {@code planNode}, pairing the estimated value
     * with the actual value (if any). Unknown estimates become {@link Optional#empty()}.
     */
    private MetricComparison createMetricComparison(Metric metric, PlanNode planNode, PlanNodeCost planNodeCost, Optional<PlanNodeCost> optional) {
        Optional<Double> estimatedValue = asOptional(metric.getValue(planNodeCost));
        Optional<Double> actualValue = optional.flatMap(actualCost -> asOptional(metric.getValue(actualCost)));
        return new MetricComparison(planNode, metric, estimatedValue, actualValue, tolerance);
    }

    /** Maps an unknown {@link Estimate} to {@link Optional#empty()}, otherwise wraps its value. */
    private Optional<Double> asOptional(Estimate estimate) {
        return estimate.isValueUnknown() ? Optional.empty() : Optional.of(estimate.getValue());
    }
}
