package org.jpmml.sparkml.model;

import java.util.List;
import java.util.NoSuchElementException;
import org.apache.spark.ml.classification.NaiveBayesModel;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sparkml.ClassificationModelConverter;
import org.jpmml.sparkml.VectorUtil;
import scala.collection.Iterator;

/* loaded from: input_file:org/jpmml/sparkml/model/NaiveBayesModelConverter.class */
/**
 * Converts a Spark ML {@link NaiveBayesModel} into a PMML classification {@link RegressionModel}.
 *
 * <p>Only the {@code "multinomial"} model type is supported. Each label category becomes one
 * {@link RegressionTable} whose coefficients are the matching row of {@code theta} and whose
 * intercept is the matching entry of {@code pi}. The tables are combined with SOFTMAX
 * normalization.
 */
public class NaiveBayesModelConverter extends ClassificationModelConverter<NaiveBayesModel> {

    /** The only Spark naive Bayes model type this converter can encode. */
    private static final String MULTINOMIAL = "multinomial";

    public NaiveBayesModelConverter(NaiveBayesModel naiveBayesModel) {
        super(naiveBayesModel);
    }

    /**
     * Encodes the wrapped Spark model as a SOFTMAX-normalized classification regression model.
     *
     * @param schema supplies the categorical label and the input features
     * @return a regression model with one regression table per label category
     * @throws IllegalArgumentException if the model type is not {@code "multinomial"}, if any
     *     decision threshold is non-zero, or if {@code pi}/{@code theta} do not provide exactly
     *     one entry/row per label category
     */
    @Override // org.jpmml.sparkml.ModelConverter
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public RegressionModel mo8encodeModel(Schema schema) {
        NaiveBayesModel naiveBayesModel = (NaiveBayesModel) getTransformer();

        String modelType = naiveBayesModel.getModelType();
        if (!MULTINOMIAL.equals(modelType)) {
            throw new IllegalArgumentException(modelType);
        }

        checkThresholds(naiveBayesModel);

        Vector pi = naiveBayesModel.pi();
        Matrix theta = naiveBayesModel.theta();

        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        int numClasses = label.size();
        // Fail with a descriptive message rather than an opaque NoSuchElementException or
        // IndexOutOfBoundsException from the row iterator / intercept list below.
        if (pi.size() != numClasses || theta.numRows() != numClasses) {
            throw new IllegalArgumentException("Expected " + numClasses
                    + " classes, got pi of size " + pi.size()
                    + " and theta with " + theta.numRows() + " rows");
        }

        List<Double> intercepts = VectorUtil.toList(pi);

        RegressionModel regressionModel = new RegressionModel(
                        MiningFunction.CLASSIFICATION,
                        ModelUtil.createMiningSchema(label),
                        (List<RegressionTable>) null)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);

        // Row i of theta holds the per-feature coefficients for label category i.
        Iterator<Vector> rows = theta.rowIter();
        for (int i = 0; i < numClasses; i++) {
            List<Double> coefficients = VectorUtil.toList(rows.next());

            RegressionTable regressionTable = RegressionModelUtil
                    .createRegressionTable(features, coefficients, intercepts.get(i))
                    .setTargetCategory(label.getValue(i));

            regressionModel.addRegressionTables(regressionTable);
        }

        return regressionModel;
    }

    /**
     * Rejects models that have any non-zero decision threshold set.
     *
     * @param naiveBayesModel the model whose thresholds are inspected
     * @throws IllegalArgumentException if a threshold is non-zero
     */
    private static void checkThresholds(NaiveBayesModel naiveBayesModel) {
        double[] thresholds;
        try {
            thresholds = naiveBayesModel.getThresholds();
        } catch (NoSuchElementException ignored) {
            // The thresholds param is optional: its getter throws when no value (and no default)
            // has been set, in which case there is nothing to validate.
            return;
        }

        for (double threshold : thresholds) {
            if (threshold != 0.0d) {
                throw new IllegalArgumentException("Non-zero thresholds are not supported");
            }
        }
    }
}
