package com.facebook.presto.grpc;

import com.facebook.airlift.concurrent.MoreFutures;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockEncodingSerde;
import com.facebook.presto.common.function.SqlFunctionResult;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.grpc.api.udf.GrpcUtils;
import com.facebook.presto.grpc.udf.GrpcFunctionHandle;
import com.facebook.presto.grpc.udf.GrpcUdfInvokeGrpc;
import com.facebook.presto.grpc.udf.GrpcUdfPage;
import com.facebook.presto.grpc.udf.GrpcUdfPageFormat;
import com.facebook.presto.grpc.udf.GrpcUdfRequest;
import com.facebook.presto.grpc.udf.GrpcUdfResult;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.RemoteScalarFunctionImplementation;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.SqlFunctionExecutor;
import com.facebook.presto.spi.function.SqlFunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.grpc.ManagedChannelBuilder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

/* loaded from: input_file:com/facebook/presto/grpc/GrpcSqlFunctionExecutor.class */
/**
 * {@link SqlFunctionExecutor} that evaluates remote scalar functions over gRPC.
 *
 * <p>The selected argument columns of the input {@link Page} are serialized into a
 * {@link GrpcUdfPage}. They are sent together with a {@link GrpcFunctionHandle} to the gRPC
 * endpoint configured for the function's language. The first block of the returned page is
 * decoded into a {@link SqlFunctionResult}.
 *
 * <p>One future stub is created per configured language, each on its own channel to that
 * language's configured gRPC address. {@link #setBlockEncodingSerde} must be called before any
 * function is executed.
 */
public class GrpcSqlFunctionExecutor implements SqlFunctionExecutor {
    /** Number of additional invocations attempted after the initial invocation fails. */
    private static final int DEFAULT_RETRY_ATTEMPTS = 3;

    private final Map<RoutineCharacteristics.Language, GrpcSqlFunctionExecutionConfig> grpcUdfConfigs;
    private final Map<RoutineCharacteristics.Language, GrpcUdfInvokeGrpc.GrpcUdfInvokeFutureStub> futureStubs = new HashMap<>();
    // Set once, after construction, through setBlockEncodingSerde; needed to (de)serialize pages.
    private BlockEncodingSerde blockEncodingSerde;

    /**
     * Creates the executor and eagerly builds one gRPC future stub per configured language.
     *
     * @param grpcUdfConfigs per-language gRPC configuration (address and page format)
     * @throws NullPointerException if {@code grpcUdfConfigs} is null
     */
    @Inject
    public GrpcSqlFunctionExecutor(Map<RoutineCharacteristics.Language, GrpcSqlFunctionExecutionConfig> grpcUdfConfigs) {
        this.grpcUdfConfigs = Objects.requireNonNull(grpcUdfConfigs, "grpcUdfConfigs is null");
        for (Map.Entry<RoutineCharacteristics.Language, GrpcSqlFunctionExecutionConfig> entry : grpcUdfConfigs.entrySet()) {
            // NOTE(review): the channels built here are never shut down by this class.
            futureStubs.put(
                    entry.getKey(),
                    GrpcUdfInvokeGrpc.newFutureStub(ManagedChannelBuilder.forTarget(entry.getValue().getGrpcAddress()).build()));
        }
    }

    @Override
    public FunctionImplementationType getImplementationType() {
        return FunctionImplementationType.GRPC;
    }

    /**
     * Installs the serde used to encode input pages and decode result pages. May be called only once.
     *
     * @throws IllegalStateException if a serde has already been set
     * @throws IllegalArgumentException if {@code blockEncodingSerde} is null
     */
    @Override
    public void setBlockEncodingSerde(BlockEncodingSerde blockEncodingSerde) {
        Preconditions.checkState(this.blockEncodingSerde == null, "blockEncodingSerde already set");
        Preconditions.checkArgument(blockEncodingSerde != null, "blockEncodingSerde is null");
        this.blockEncodingSerde = blockEncodingSerde;
    }

    /**
     * Invokes the remote function on the selected columns of {@code input}.
     *
     * @param source request source sent to the server as-is
     * @param functionImplementation remote implementation; its language selects the gRPC endpoint
     * @param input page holding the argument columns
     * @param channels indexes of the argument columns within {@code input}, in argument order
     * @param argumentTypes argument types (not read here; the handle's own argument types are sent)
     * @param returnType return type, sent to the server by its string form
     * @return future of the decoded result; it fails if every attempt fails
     * @throws IllegalArgumentException if no gRPC configuration exists for the function's language
     */
    @Override
    public CompletableFuture<SqlFunctionResult> executeFunction(
            String source,
            RemoteScalarFunctionImplementation functionImplementation,
            Page input,
            List<Integer> channels,
            List<Type> argumentTypes,
            Type returnType) {
        RoutineCharacteristics.Language language = functionImplementation.getLanguage();
        GrpcSqlFunctionExecutionConfig config = grpcUdfConfigs.get(language);
        GrpcUdfInvokeGrpc.GrpcUdfInvokeFutureStub futureStub = futureStubs.get(language);
        // Fail with a descriptive message instead of a bare NPE on an unconfigured language.
        Preconditions.checkArgument(config != null && futureStub != null, "No gRPC UDF configuration for language: %s", language);

        GrpcUdfPage inputs = buildGrpcUdfPage(input, channels, config.getGrpcUdfPageFormat());
        GrpcUdfRequest request = GrpcUdfRequest.newBuilder()
                .setSource(source)
                .setGrpcFunctionHandle(buildGrpcFunctionHandle(functionImplementation.getFunctionHandle(), returnType))
                .setInputs(inputs)
                .build();
        return invokeUdfWithRetry(futureStub, request).thenApply(this::toSqlFunctionResult);
    }

    /** Builds the protobuf handle (name, argument types, return type, version) naming the function to run. */
    private static GrpcFunctionHandle buildGrpcFunctionHandle(SqlFunctionHandle functionHandle, Type returnType) {
        SqlFunctionId functionId = functionHandle.getFunctionId();
        List<String> argumentTypeNames = functionId.getArgumentTypes().stream()
                .map(Object::toString)
                .collect(ImmutableList.toImmutableList());
        return GrpcFunctionHandle.newBuilder()
                .setFunctionName(functionId.getFunctionName().toString())
                .addAllArgumentTypes(argumentTypeNames)
                .setReturnType(returnType.toString())
                .setVersion(functionHandle.getVersion())
                .build();
    }

    /**
     * Performs a single gRPC invocation.
     *
     * <p>A synchronous failure from the stub is reported through the returned future rather than
     * thrown. This lets the retry logic treat a failed first attempt like any later failed attempt.
     */
    private CompletableFuture<GrpcUdfResult> invokeUdf(GrpcUdfInvokeGrpc.GrpcUdfInvokeFutureStub futureStub, GrpcUdfRequest request) {
        try {
            return MoreFutures.toCompletableFuture(futureStub.invokeUdf(request));
        } catch (Exception e) {
            CompletableFuture<GrpcUdfResult> failed = new CompletableFuture<>();
            failed.completeExceptionally(e);
            return failed;
        }
    }

    /**
     * Invokes the UDF, re-invoking it up to {@link #DEFAULT_RETRY_ATTEMPTS} more times on failure.
     *
     * <p>No backoff is applied between attempts.
     */
    private CompletableFuture<GrpcUdfResult> invokeUdfWithRetry(GrpcUdfInvokeGrpc.GrpcUdfInvokeFutureStub futureStub, GrpcUdfRequest request) {
        CompletableFuture<GrpcUdfResult> resultFuture = invokeUdf(futureStub, request);
        for (int attempt = 0; attempt < DEFAULT_RETRY_ATTEMPTS; attempt++) {
            // A success passes through unchanged.
            // A failure is replaced by a fresh invocation whose own outcome is then awaited.
            resultFuture = resultFuture
                    .thenApply(CompletableFuture::completedFuture)
                    .exceptionally(failure -> invokeUdf(futureStub, request))
                    .thenCompose(Function.identity());
        }
        return resultFuture;
    }

    /**
     * Serializes the blocks at {@code channels} of {@code page} into a {@link GrpcUdfPage}.
     *
     * @throws IllegalStateException if the serde has not been set (Presto format)
     * @throws IllegalArgumentException if {@code pageFormat} is not supported
     */
    private GrpcUdfPage buildGrpcUdfPage(Page page, List<Integer> channels, GrpcUdfPageFormat pageFormat) {
        Block[] blocks = new Block[channels.size()];
        for (int i = 0; i < channels.size(); i++) {
            blocks[i] = page.getBlock(channels.get(i));
        }
        switch (pageFormat) {
            case Presto:
                Preconditions.checkState(this.blockEncodingSerde != null, "blockEncodingSerde not set");
                return GrpcUtils.toGrpcUdfPage(pageFormat, GrpcUtils.toGrpcSerializedPage(this.blockEncodingSerde, Page.wrapBlocksWithoutCopy(page.getPositionCount(), blocks)));
            default:
                throw new IllegalArgumentException(String.format("Unknown page format: %s", pageFormat));
        }
    }

    /**
     * Decodes the server response.
     *
     * <p>The first block of the returned page becomes the result value. The server-reported total
     * CPU time is carried along.
     *
     * @throws IllegalStateException if the serde has not been set
     * @throws IllegalArgumentException if the response uses an unsupported page format
     */
    private SqlFunctionResult toSqlFunctionResult(GrpcUdfResult grpcUdfResult) {
        Preconditions.checkState(this.blockEncodingSerde != null, "blockEncodingSerde not set");
        GrpcUdfPage result = grpcUdfResult.getResult();
        switch (result.getGrpcUdfPageFormat()) {
            case Presto:
                return new SqlFunctionResult(
                        GrpcUtils.toPrestoPage(this.blockEncodingSerde, result.getGrpcSerializedPage()).getBlock(0),
                        grpcUdfResult.getUdfStats().getTotalCpuTimeMs());
            default:
                throw new IllegalArgumentException(String.format("Unknown page format: %s", result.getGrpcUdfPageFormat()));
        }
    }
}
