package com.facebook.presto.ml;

import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.CombineFunction;
import com.facebook.presto.operator.aggregation.InputFunction;
import com.facebook.presto.operator.aggregation.OutputFunction;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.type.SqlType;
import io.airlift.slice.Slice;

/**
 * Aggregation {@code learn_libsvm_classifier}: collects labels and feature vectors for every input
 * row, then trains an SVM-based classifier in the output step and emits it serialized as a
 * {@code Classifier<bigint>}.
 *
 * <p>The aggregation is registered with {@code decomposable = false}, and {@link #combine} throws
 * unconditionally, so all rows must be accumulated into one state.
 */
@AggregationFunction(value = "learn_libsvm_classifier", decomposable = false)
public final class LearnLibSvmClassifierAggregation {
    private LearnLibSvmClassifierAggregation() {
    }

    /**
     * Accumulates one row whose label is a {@code bigint}.
     *
     * <p>Delegates to the {@code double}-label overload. The explicit {@code (double)} cast is
     * required: without it, Java overload resolution selects this {@code long} overload as the
     * most specific match, and the method recurses into itself until the stack overflows.
     *
     * @param learnState accumulation state for this aggregation
     * @param label row label
     * @param features serialized {@code map<bigint,double>} of feature values
     * @param parameters LibSVM parameter string
     */
    @InputFunction
    public static void input(LearnState learnState, @SqlType("bigint") long label, @SqlType("map<bigint,double>") Slice features, @SqlType("varchar") Slice parameters) {
        input(learnState, (double) label, features, parameters);
    }

    /**
     * Accumulates one row whose label is a {@code double}.
     *
     * <p>Appends the label and the decoded feature vector to the state, adds the vector's estimated
     * size to the state's memory accounting, and records {@code parameters}. Each row overwrites
     * the stored parameters, so the value from the last row processed is the one used by
     * {@link #output}.
     *
     * @param learnState accumulation state for this aggregation
     * @param label row label
     * @param features serialized {@code map<bigint,double>} of feature values
     * @param parameters LibSVM parameter string
     */
    @InputFunction
    public static void input(LearnState learnState, @SqlType("double") double label, @SqlType("map<bigint,double>") Slice features, @SqlType("varchar") Slice parameters) {
        learnState.getLabels().add(Double.valueOf(label));
        FeatureVector featureVector = ModelUtils.toFeatures(features);
        learnState.addMemoryUsage(featureVector.getEstimatedSize());
        learnState.getFeatureVectors().add(featureVector);
        learnState.setParameters(parameters);
    }

    /**
     * Always fails: training needs every row in a single state, so partial states cannot be merged.
     *
     * @throws UnsupportedOperationException always
     */
    @CombineFunction
    public static void combine(LearnState learnState, LearnState otherState) {
        throw new UnsupportedOperationException("LEARN must run on a single machine");
    }

    /**
     * Trains the classifier on the accumulated rows and writes it to {@code blockBuilder}.
     *
     * <p>The model is an {@code SvmClassifier}, configured from the stored LibSVM parameter
     * string and wrapped in a {@code ClassifierFeatureTransformer} with a
     * {@code FeatureUnitNormalizer}.
     *
     * <p>NOTE(review): {@code learnState.getParameters()} is dereferenced unconditionally. If no row
     * was ever accumulated, it may be null; confirm how {@code LearnState} initializes it.
     *
     * @param learnState accumulated labels, feature vectors and parameters
     * @param blockBuilder destination for the serialized {@code Classifier<bigint>}
     */
    @OutputFunction("Classifier<bigint>")
    public static void output(LearnState learnState, BlockBuilder blockBuilder) {
        Dataset dataset = new Dataset(learnState.getLabels(), learnState.getFeatureVectors(), learnState.getLabelEnumeration().inverse());
        ClassifierFeatureTransformer model = new ClassifierFeatureTransformer(
                new SvmClassifier(LibSvmUtils.parseParameters(learnState.getParameters().toStringUtf8())),
                new FeatureUnitNormalizer());
        model.train(dataset);
        ClassifierType.BIGINT_CLASSIFIER.writeSlice(blockBuilder, ModelUtils.serialize(model));
    }
}
