package io.trino.execution.scheduler;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import com.google.common.math.Stats;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.annotation.NotThreadSafe;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.StageId;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import org.weakref.jmx.Managed;

/* loaded from: input_file:io/trino/execution/scheduler/TaskDescriptorStorage.class */
/**
 * In-memory, per-query storage for {@link TaskDescriptor}s with a global memory limit.
 * <p>
 * A query's storage is created by {@link #initialize(QueryId)} and released by {@link #destroy(QueryId)}.
 * For a query without storage, {@code put}/{@code remove} are no-ops and {@code get} returns
 * {@link Optional#empty()}.
 * <p>
 * The retained size of all stored descriptors is tracked. Whenever a reservation increase pushes the total
 * above the configured maximum, the query holding the most descriptor memory is failed (its descriptors are
 * dropped), repeatedly, until the total fits again. Any later access to a failed query's storage rethrows a
 * {@link TrinoException} with {@code EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY}.
 * <p>
 * All public methods synchronize on this instance.
 */
public class TaskDescriptorStorage {
    private static final Logger log = Logger.get(TaskDescriptorStorage.class);

    private final long maxMemoryInBytes;

    @GuardedBy("this")
    private final Map<QueryId, TaskDescriptors> storages;

    // sum of getReservedBytes() over all storages currently in the map
    @GuardedBy("this")
    private long reservedBytes;

    /**
     * Key of a stored descriptor within one query: the stage id plus the partition id.
     */
    public static class TaskDescriptorKey {
        private final StageId stageId;
        private final int partitionId;

        private TaskDescriptorKey(StageId stageId, int partitionId) {
            this.stageId = Objects.requireNonNull(stageId, "stageId is null");
            this.partitionId = partitionId;
        }

        public StageId getStageId() {
            return stageId;
        }

        public int getPartitionId() {
            return partitionId;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            TaskDescriptorKey that = (TaskDescriptorKey) obj;
            return partitionId == that.partitionId && Objects.equals(stageId, that.stageId);
        }

        @Override
        public int hashCode() {
            return Objects.hash(stageId, partitionId);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("stageId", stageId)
                    .add("partitionId", partitionId)
                    .toString();
        }
    }

    /**
     * Descriptors of a single query plus their total retained size.
     * <p>
     * Once {@link #fail(RuntimeException)} is called, the descriptors are dropped and every subsequent
     * {@code put}/{@code get}/{@code remove} rethrows the recorded failure. Not thread-safe; the enclosing
     * storage only touches it while holding its own monitor.
     */
    @NotThreadSafe
    public static class TaskDescriptors {
        private final Map<TaskDescriptorKey, TaskDescriptor> descriptors = new HashMap<>();
        private long reservedBytes;
        private RuntimeException failure;

        private TaskDescriptors() {
        }

        /**
         * @throws IllegalStateException if a descriptor is already stored for this stage and partition
         */
        public void put(StageId stageId, int partitionId, TaskDescriptor descriptor) {
            throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            Preconditions.checkState(descriptors.putIfAbsent(key, descriptor) == null, "task descriptor is already present for key %s ", key);
            reservedBytes += descriptor.getRetainedSizeInBytes();
        }

        /**
         * @throws NoSuchElementException if no descriptor is stored for this stage and partition
         */
        public TaskDescriptor get(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            TaskDescriptor descriptor = descriptors.get(key);
            if (descriptor == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s", key));
            }
            return descriptor;
        }

        /**
         * @throws NoSuchElementException if no descriptor is stored for this stage and partition
         */
        public void remove(StageId stageId, int partitionId) {
            throwIfFailed();
            TaskDescriptorKey key = new TaskDescriptorKey(stageId, partitionId);
            TaskDescriptor removed = descriptors.remove(key);
            if (removed == null) {
                throw new NoSuchElementException(String.format("descriptor not found for key %s", key));
            }
            reservedBytes -= removed.getRetainedSizeInBytes();
        }

        public long getReservedBytes() {
            return reservedBytes;
        }

        /**
         * Renders per-stage size and split statistics of the stored descriptors, for logging.
         */
        private String getDebugInfo() {
            Multimap<StageId, TaskDescriptor> descriptorsByStageId = descriptors.entrySet().stream()
                    .collect(ImmutableSetMultimap.toImmutableSetMultimap(
                            entry -> entry.getKey().getStageId(),
                            Map.Entry::getValue));

            Map<StageId, String> debugInfoByStageId = descriptorsByStageId.asMap().entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getKey,
                            entry -> getDebugInfo(entry.getValue())));

            return String.valueOf(debugInfoByStageId);
        }

        /**
         * Summarizes a non-empty group of descriptors: count, retained-size mean/stddev, and per plan node
         * split-count and split-size mean/stddev.
         */
        private static String getDebugInfo(Collection<TaskDescriptor> taskDescriptors) {
            int taskDescriptorsCount = taskDescriptors.size();
            Stats retainedSizeStats = Stats.of(taskDescriptors.stream().mapToLong(TaskDescriptor::getRetainedSizeInBytes));

            Set<PlanNodeId> planNodeIds = taskDescriptors.stream()
                    .flatMap(taskDescriptor -> taskDescriptor.getSplits().keySet().stream())
                    .collect(ImmutableSet.toImmutableSet());

            Map<PlanNodeId, String> splitsDebugInfo = new HashMap<>();
            for (PlanNodeId planNodeId : planNodeIds) {
                // plan node ids are the union over all descriptors, so some descriptors may have no splits for
                // this node; get() yields an empty collection there (asMap().get() would yield null -> NPE)
                Stats splitCountStats = Stats.of(taskDescriptors.stream()
                        .mapToLong(taskDescriptor -> taskDescriptor.getSplits().get(planNodeId).size()));
                Stats splitSizeStats = Stats.of(taskDescriptors.stream()
                        .flatMap(taskDescriptor -> taskDescriptor.getSplits().get(planNodeId).stream())
                        .mapToLong(split -> split.getRetainedSizeInBytes()));
                splitsDebugInfo.put(
                        planNodeId,
                        "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted(
                                splitCountStats.mean(),
                                splitCountStats.populationStandardDeviation(),
                                splitSizeStats.mean(),
                                splitSizeStats.populationStandardDeviation()));
            }

            return "[taskDescriptorsCount=%s, taskDescriptorsRetainedSizeMean=%s, taskDescriptorsRetainedSizeStdDev=%s, splits=%s]".formatted(
                    taskDescriptorsCount,
                    retainedSizeStats.mean(),
                    retainedSizeStats.populationStandardDeviation(),
                    splitsDebugInfo);
        }

        /**
         * Drops all descriptors and records {@code failure}; only the first failure is kept.
         */
        private void fail(RuntimeException failure) {
            if (this.failure == null) {
                descriptors.clear();
                reservedBytes = 0L;
                this.failure = failure;
            }
        }

        private void throwIfFailed() {
            if (failure != null) {
                throw failure;
            }
        }
    }

    @Inject
    public TaskDescriptorStorage(QueryManagerConfig queryManagerConfig) {
        this(queryManagerConfig.getFaultTolerantExecutionTaskDescriptorStorageMaxMemory());
    }

    /**
     * @param dataSize maximum total retained size of stored descriptors across all queries
     */
    public TaskDescriptorStorage(DataSize dataSize) {
        this.storages = new HashMap<>();
        this.maxMemoryInBytes = dataSize.toBytes();
    }

    /**
     * Creates an empty storage for {@code queryId}.
     *
     * @throws VerifyException if storage for the query is already initialized
     */
    public synchronized void initialize(QueryId queryId) {
        TaskDescriptors taskDescriptors = new TaskDescriptors();
        Verify.verify(storages.putIfAbsent(queryId, taskDescriptors) == null, "storage is already initialized for query: %s", queryId);
        updateMemoryReservation(taskDescriptors.getReservedBytes());
    }

    /**
     * Stores {@code taskDescriptor} under {@code stageId} and its partition id, then enforces the memory limit
     * (which may fail this or another query). No-op if the query's storage is not initialized.
     *
     * @throws IllegalStateException if a descriptor is already stored for the same stage and partition
     */
    public synchronized void put(StageId stageId, TaskDescriptor taskDescriptor) {
        TaskDescriptors taskDescriptors = storages.get(stageId.getQueryId());
        if (taskDescriptors == null) {
            return;
        }
        long previousReservedBytes = taskDescriptors.getReservedBytes();
        taskDescriptors.put(stageId, taskDescriptor.getPartitionId(), taskDescriptor);
        updateMemoryReservation(taskDescriptors.getReservedBytes() - previousReservedBytes);
    }

    /**
     * Returns the stored descriptor, or empty if the query's storage is not initialized.
     *
     * @throws NoSuchElementException if the storage exists but holds no such descriptor
     */
    public synchronized Optional<TaskDescriptor> get(StageId stageId, int partitionId) {
        TaskDescriptors taskDescriptors = storages.get(stageId.getQueryId());
        if (taskDescriptors == null) {
            return Optional.empty();
        }
        return Optional.of(taskDescriptors.get(stageId, partitionId));
    }

    /**
     * Removes the descriptor and releases its memory. No-op if the query's storage is not initialized.
     *
     * @throws NoSuchElementException if the storage exists but holds no such descriptor
     */
    public synchronized void remove(StageId stageId, int partitionId) {
        TaskDescriptors taskDescriptors = storages.get(stageId.getQueryId());
        if (taskDescriptors == null) {
            return;
        }
        long previousReservedBytes = taskDescriptors.getReservedBytes();
        taskDescriptors.remove(stageId, partitionId);
        updateMemoryReservation(taskDescriptors.getReservedBytes() - previousReservedBytes);
    }

    /**
     * Drops the query's storage (if any) and releases all memory it held.
     */
    public synchronized void destroy(QueryId queryId) {
        TaskDescriptors removed = storages.remove(queryId);
        if (removed != null) {
            updateMemoryReservation(-removed.getReservedBytes());
        }
    }

    /**
     * Applies {@code delta} to the global reservation. When it grows past the limit, fails the query holding
     * the most descriptor memory, repeatedly, until the reservation fits again.
     */
    private synchronized void updateMemoryReservation(long delta) {
        reservedBytes += delta;
        if (delta <= 0) {
            // releasing memory cannot push the reservation over the limit
            return;
        }
        while (reservedBytes > maxMemoryInBytes) {
            Map.Entry<QueryId, TaskDescriptors> killCandidate = storages.entrySet().stream()
                    .max(Comparator.comparingLong(entry -> entry.getValue().getReservedBytes()))
                    .orElseThrow(() -> new VerifyException(String.format("storage is empty but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", reservedBytes, maxMemoryInBytes)));
            QueryId queryId = killCandidate.getKey();
            TaskDescriptors storage = killCandidate.getValue();
            long previousReservedBytes = storage.getReservedBytes();
            // failing a storage that holds nothing reclaims nothing; without this guard an accounting
            // inconsistency would make the loop spin forever while holding the monitor
            Verify.verify(previousReservedBytes > 0, "largest storage (query %s) holds no memory but reservedBytes (%s) is still greater than maxMemoryInBytes (%s)", queryId, reservedBytes, maxMemoryInBytes);

            // building the debug info walks every stored descriptor, so skip it when INFO is disabled
            if (log.isInfoEnabled()) {
                log.info("Failing query %s; reclaiming %s of %s task descriptor memory from %s queries; extraStorageInfo=%s", queryId, DataSize.succinctBytes(previousReservedBytes), DataSize.succinctBytes(reservedBytes), storages.size(), storage.getDebugInfo());
            }

            storage.fail(new TrinoException(
                    StandardErrorCode.EXCEEDED_TASK_DESCRIPTOR_STORAGE_CAPACITY,
                    String.format("Task descriptor storage capacity has been exceeded: %s > %s", DataSize.succinctBytes(reservedBytes), DataSize.succinctBytes(maxMemoryInBytes))));
            reservedBytes += storage.getReservedBytes() - previousReservedBytes;
        }
    }

    @Managed
    public synchronized long getReservedBytes() {
        return reservedBytes;
    }
}
