package com.facebook.presto.ml;

import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/ml/TestModelSerialization.class */
/**
 * Tests for {@link ModelUtils} serialization.
 *
 * <p>Each model test trains a model on {@code TestUtils.getDataset()}, serializes it, deserializes
 * the bytes, and checks that a non-null instance of the expected concrete type comes back. The
 * id test pins the numeric ids registered in {@code ModelUtils.MODEL_SERIALIZATION_IDS}.
 * NOTE(review): presumably these ids are persisted with serialized models, so changing one would
 * break previously stored data — confirm against ModelUtils.
 */
public class TestModelSerialization {
    @Test
    public void testSvmClassifier() {
        SvmClassifier classifier = new SvmClassifier();
        classifier.train(TestUtils.getDataset());
        assertRoundTrip(classifier, SvmClassifier.class, "deserialized model is not a svm");
    }

    @Test
    public void testSvmRegressor() {
        SvmRegressor regressor = new SvmRegressor();
        regressor.train(TestUtils.getDataset());
        assertRoundTrip(regressor, SvmRegressor.class, "deserialized model is not a svm");
    }

    @Test
    public void testRegressorFeatureTransformer() {
        RegressorFeatureTransformer transformer = new RegressorFeatureTransformer(new SvmRegressor(), new FeatureVectorUnitNormalizer());
        transformer.train(TestUtils.getDataset());
        assertRoundTrip(transformer, RegressorFeatureTransformer.class, "deserialized model is not a regressor feature transformer");
    }

    @Test
    public void testClassifierFeatureTransformer() {
        ClassifierFeatureTransformer transformer = new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer());
        transformer.train(TestUtils.getDataset());
        assertRoundTrip(transformer, ClassifierFeatureTransformer.class, "deserialized model is not a classifier feature transformer");
    }

    @Test
    public void testSerializationIds() {
        assertSerializationId(SvmClassifier.class, 1);
        assertSerializationId(SvmRegressor.class, 2);
        assertSerializationId(FeatureVectorUnitNormalizer.class, 3);
        assertSerializationId(ClassifierFeatureTransformer.class, 4);
        assertSerializationId(RegressorFeatureTransformer.class, 5);
    }

    /**
     * Serializes {@code model}, deserializes the result, and asserts that the round trip yields a
     * non-null instance of {@code expectedType}.
     *
     * @param model an already-trained model to round-trip
     * @param expectedType the concrete class the deserialized model must be an instance of
     * @param typeMessage failure message used when the deserialized model has the wrong type
     */
    private static void assertRoundTrip(Model model, Class<?> expectedType, String typeMessage) {
        Model deserialized = ModelUtils.deserialize(ModelUtils.serialize(model));
        Assert.assertNotNull(deserialized, "deserialization failed");
        Assert.assertTrue(expectedType.isInstance(deserialized), typeMessage);
    }

    /**
     * Asserts that {@code modelClass} is registered in {@code ModelUtils.MODEL_SERIALIZATION_IDS}
     * with {@code expectedId}.
     *
     * @param modelClass the class whose serialization id is checked
     * @param expectedId the id the class must map to
     */
    private static void assertSerializationId(Class<?> modelClass, int expectedId) {
        // Compare boxed values: a missing mapping then fails with "expected [n] but found [null]"
        // rather than a NullPointerException from unboxing null via intValue().
        Assert.assertEquals(
                ModelUtils.MODEL_SERIALIZATION_IDS.get(modelClass),
                Integer.valueOf(expectedId),
                "unexpected serialization id for " + modelClass.getName());
    }
}
