package com.facebook.presto.ml;

import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.ml.type.RegressorType;
import com.facebook.presto.operator.scalar.ScalarFunction;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.VariableWidthBlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.type.SqlType;
import com.facebook.presto.util.Types;
import com.google.common.base.Preconditions;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashCode;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

/* loaded from: input_file:com/facebook/presto/ml/MLFunctions.class */
/**
 * SQL scalar functions for applying serialized machine-learning models to feature maps, plus
 * helpers that build feature maps from positional doubles.
 *
 * <p>Models arrive as serialized {@link Slice} values (SQL types {@code Classifier<varchar>},
 * {@code Classifier<bigint>} or {@code Regressor}). They are deserialized with
 * {@code ModelUtils.deserialize}. The most recently used deserialized models are kept in a small
 * size-bounded cache, keyed by {@code ModelUtils.modelHash} of the serialized form.
 *
 * <p>Feature maps have SQL type {@code map(bigint,double)}. The {@code features(...)} overloads
 * build such a map from 1 to 10 doubles, keyed by argument position starting at 0.
 */
public final class MLFunctions {
    /** Maximum number of deserialized models retained in {@link #MODEL_CACHE}. */
    private static final int MODEL_CACHE_SIZE = 5;

    /** Deserialized models keyed by the hash of their serialized form; evicts by size. */
    private static final Cache<HashCode, Model> MODEL_CACHE = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE).build();

    /** SQL type signature of a feature map (feature index to feature value). */
    private static final String MAP_BIGINT_DOUBLE = "map(bigint,double)";

    // Static utility holder; never instantiated.
    private MLFunctions() {
    }

    /**
     * Classifies a feature map with a {@code Classifier<varchar>} model.
     *
     * @param block feature map of SQL type {@code map(bigint,double)}
     * @param slice serialized model
     * @return the predicted label as a UTF-8 slice
     * @throws IllegalArgumentException if the model's type is not {@code Classifier<varchar>}
     */
    @ScalarFunction("classify")
    @SqlType("varchar")
    public static Slice varcharClassify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier<varchar>") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.VARCHAR_CLASSIFIER), "model is not a classifier<varchar>");
        Classifier classifier = (Classifier) Types.checkType(model, Classifier.class, "model");
        return Slices.utf8Slice((String) classifier.classify(features));
    }

    /**
     * Classifies a feature map with a {@code Classifier<bigint>} model.
     *
     * @param block feature map of SQL type {@code map(bigint,double)}
     * @param slice serialized model
     * @return the predicted label
     * @throws IllegalArgumentException if the model's type is not {@code Classifier<bigint>}
     */
    @ScalarFunction
    @SqlType("bigint")
    public static long classify(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Classifier<bigint>") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        // Keep the loaded model in a local: it is needed both for the type check and for classification.
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(ClassifierType.BIGINT_CLASSIFIER), "model is not a classifier<bigint>");
        Classifier classifier = (Classifier) Types.checkType(model, Classifier.class, "model");
        // The label is expected to come back as a boxed Integer; it is widened to the SQL bigint (long) result.
        return ((Integer) classifier.classify(features)).intValue();
    }

    /**
     * Applies a {@code Regressor} model to a feature map.
     *
     * @param block feature map of SQL type {@code map(bigint,double)}
     * @param slice serialized model
     * @return the regression output
     * @throws IllegalArgumentException if the model's type is not {@code Regressor}
     */
    @ScalarFunction
    @SqlType("double")
    public static double regress(@SqlType(MAP_BIGINT_DOUBLE) Block block, @SqlType("Regressor") Slice slice) {
        FeatureVector features = ModelUtils.toFeatures(block);
        Model model = getOrLoadModel(slice);
        Preconditions.checkArgument(model.getType().equals(RegressorType.REGRESSOR), "model is not a regressor");
        Regressor regressor = (Regressor) Types.checkType(model, Regressor.class, "model");
        return regressor.regress(features);
    }

    /**
     * Returns the deserialized model for {@code slice}, deserializing and caching it on a cache miss.
     *
     * <p>The lookup and the insert are separate steps, not one atomic load. Concurrent misses on the
     * same model may therefore each deserialize it, and the last {@code put} wins. NOTE(review): this
     * is only a wasted-work cost if {@code ModelUtils.deserialize} has no side effects; confirm.
     *
     * @param slice serialized model
     * @return the deserialized model
     */
    private static Model getOrLoadModel(Slice slice) {
        HashCode modelHash = ModelUtils.modelHash(slice);
        Model model = MODEL_CACHE.getIfPresent(modelHash);
        if (model == null) {
            model = ModelUtils.deserialize(slice);
            MODEL_CACHE.put(modelHash, model);
        }
        return model;
    }

    // One SQL-visible overload per arity (1 through 10). Each builds a map(bigint,double) whose
    // keys are the argument positions 0..n-1 and whose values are the arguments.

    /** Builds a feature map from 1 feature. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d) {
        return featuresHelper(d);
    }

    /** Builds a feature map from 2 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2) {
        return featuresHelper(d, d2);
    }

    /** Builds a feature map from 3 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3) {
        return featuresHelper(d, d2, d3);
    }

    /** Builds a feature map from 4 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4) {
        return featuresHelper(d, d2, d3, d4);
    }

    /** Builds a feature map from 5 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5) {
        return featuresHelper(d, d2, d3, d4, d5);
    }

    /** Builds a feature map from 6 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5, @SqlType("double") double d6) {
        return featuresHelper(d, d2, d3, d4, d5, d6);
    }

    /** Builds a feature map from 7 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5, @SqlType("double") double d6, @SqlType("double") double d7) {
        return featuresHelper(d, d2, d3, d4, d5, d6, d7);
    }

    /** Builds a feature map from 8 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5, @SqlType("double") double d6, @SqlType("double") double d7, @SqlType("double") double d8) {
        return featuresHelper(d, d2, d3, d4, d5, d6, d7, d8);
    }

    /** Builds a feature map from 9 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5, @SqlType("double") double d6, @SqlType("double") double d7, @SqlType("double") double d8, @SqlType("double") double d9) {
        return featuresHelper(d, d2, d3, d4, d5, d6, d7, d8, d9);
    }

    /** Builds a feature map from 10 features. */
    @ScalarFunction
    @SqlType(MAP_BIGINT_DOUBLE)
    public static Block features(@SqlType("double") double d, @SqlType("double") double d2, @SqlType("double") double d3, @SqlType("double") double d4, @SqlType("double") double d5, @SqlType("double") double d6, @SqlType("double") double d7, @SqlType("double") double d8, @SqlType("double") double d9, @SqlType("double") double d10) {
        return featuresHelper(d, d2, d3, d4, d5, d6, d7, d8, d9, d10);
    }

    /**
     * Builds a {@code map(bigint,double)} block mapping each position {@code i} to {@code dArr[i]}.
     *
     * <p>Entries are written as alternating positions: a BIGINT key (the index) followed by its
     * DOUBLE value.
     *
     * @param dArr feature values, in key order
     * @return the encoded feature map
     */
    private static Block featuresHelper(double... dArr) {
        // NOTE(review): the third constructor argument (16) is presumably a per-entry size hint;
        // confirm against the VariableWidthBlockBuilder signature in use.
        VariableWidthBlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), dArr.length, 16);
        for (int i = 0; i < dArr.length; i++) {
            BigintType.BIGINT.writeLong(blockBuilder, i);
            DoubleType.DOUBLE.writeDouble(blockBuilder, dArr[i]);
        }
        return blockBuilder.build();
    }
}
