package com.facebook.presto.ml;

import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.CombineFunction;
import com.facebook.presto.operator.aggregation.InputFunction;
import com.facebook.presto.operator.aggregation.OutputFunction;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.type.SqlType;
import io.airlift.slice.Slice;

import java.util.Objects;

/**
 * Aggregation {@code learn_libsvm_classifier} that trains a LIBSVM-backed classifier whose labels
 * are {@code varchar} values.
 *
 * <p>Each input row contributes one label, one feature vector and a LIBSVM parameter string.
 * {@link #output} trains the model over all accumulated rows and writes it as a
 * {@code Classifier<varchar>}. The aggregation is declared non-decomposable, and
 * {@link #combine} refuses to merge partial states.
 */
@AggregationFunction(value = "learn_libsvm_classifier", decomposable = false)
public final class LearnLibSvmVarcharClassifierAggregation {
    private LearnLibSvmVarcharClassifierAggregation() {
    }

    /**
     * Accumulates one training example into {@code learnState}.
     *
     * <p>Only the feature vector's estimated size is reported via {@code addMemoryUsage}; the
     * label list growth is not accounted here.
     *
     * @param learnState per-group training state
     * @param label the example's label; converted to a numeric value with
     *     {@code LearnState.enumerateLabel} and stored as a {@link Double}
     * @param features the example's features, a SQL {@code map<bigint,double>}
     * @param parameters LIBSVM parameter string; it is passed to {@code setParameters} on every
     *     row (presumably overwriting, so the last row processed wins — confirm in LearnState)
     */
    @InputFunction
    public static void input(LearnState learnState, @SqlType("varchar") Slice label, @SqlType("map<bigint,double>") Slice features, @SqlType("varchar") Slice parameters) {
        learnState.getLabels().add(Double.valueOf(learnState.enumerateLabel(label.toStringUtf8())));
        FeatureVector featureVector = ModelUtils.toFeatures(features);
        learnState.addMemoryUsage(featureVector.getEstimatedSize());
        learnState.getFeatureVectors().add(featureVector);
        learnState.setParameters(parameters);
    }

    /**
     * Always fails: partial training states cannot be merged.
     *
     * @throws UnsupportedOperationException always
     */
    @CombineFunction
    public static void combine(LearnState learnState, LearnState otherState) {
        throw new UnsupportedOperationException("LEARN must run on a single machine");
    }

    /**
     * Trains the classifier on everything accumulated in {@code learnState} and writes the
     * serialized model to {@code blockBuilder}.
     *
     * <p>The model is an SVM configured from the recorded LIBSVM parameters, wrapped with unit
     * normalization of the features and adapted to string labels.
     *
     * @throws NullPointerException if no LIBSVM parameters were recorded, i.e. {@link #input}
     *     was never called for this state
     */
    @OutputFunction("Classifier<varchar>")
    public static void output(LearnState learnState, BlockBuilder blockBuilder) {
        Dataset dataset = new Dataset(learnState.getLabels(), learnState.getFeatureVectors(), learnState.getLabelEnumeration().inverse());

        // Parameters are only recorded by input(); fail with a clear message instead of a bare NPE.
        Slice parameters = Objects.requireNonNull(learnState.getParameters(), "LIBSVM parameters were never set; learn_libsvm_classifier requires at least one input row");
        SvmClassifier svm = new SvmClassifier(LibSvmUtils.parseParameters(parameters.toStringUtf8()));

        StringClassifierAdapter model = new StringClassifierAdapter(new ClassifierFeatureTransformer(svm, new FeatureUnitNormalizer()));
        model.train(dataset);
        ClassifierType.VARCHAR_CLASSIFIER.writeSlice(blockBuilder, ModelUtils.serialize(model));
    }
}
