package org.tribuo.multilabel.sgd.linear;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.common.sgd.AbstractSGDModel;
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/multilabel/sgd/linear/LinearSGDModel.class */
public class LinearSGDModel extends AbstractLinearSGDModel<MultiLabel> {
    private static final long serialVersionUID = 2;
    private final VectorNormalizer normalizer;
    private final double threshold;

    /* JADX INFO: Access modifiers changed from: package-private */
    public LinearSGDModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<MultiLabel> immutableOutputInfo, LinearParameters linearParameters, VectorNormalizer vectorNormalizer, boolean z, double d) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, linearParameters, z);
        this.normalizer = vectorNormalizer;
        this.threshold = d;
    }

    public Prediction<MultiLabel> predict(Example<MultiLabel> example) {
        AbstractSGDModel.PredAndActive predictSingle = predictSingle(example);
        DenseVector denseVector = predictSingle.prediction;
        denseVector.normalize(this.normalizer);
        HashMap hashMap = new HashMap();
        HashSet hashSet = new HashSet();
        for (int i = 0; i < denseVector.size(); i++) {
            String labelString = this.outputIDInfo.getOutput(i).getLabelString();
            double d = denseVector.get(i);
            Label label = new Label(this.outputIDInfo.getOutput(i).getLabelString(), d);
            if (d > this.threshold) {
                hashSet.add(label);
            }
            hashMap.put(labelString, new MultiLabel(label));
        }
        return new Prediction<>(new MultiLabel(hashSet), hashMap, predictSingle.numActiveFeatures - 1, example, this.generatesProbabilities);
    }

    protected String getDimensionName(int i) {
        return this.outputIDInfo.getOutput(i).getLabelString();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    public LinearSGDModel m0copy(String str, ModelProvenance modelProvenance) {
        return new LinearSGDModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.modelParameters.copy(), this.normalizer, this.generatesProbabilities, this.threshold);
    }
}
