package org.tribuo.classification.baseline;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.ImmutableLabelInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/baseline/DummyClassifierModel.class */
/**
 * A baseline classification model which never looks at the features of an example.
 * <p>
 * The {@link DummyClassifierTrainer.DummyType} chosen at construction selects the strategy:
 * <ul>
 *     <li>{@code CONSTANT} - always predicts a caller-supplied label.</li>
 *     <li>{@code MOST_FREQUENT} - always predicts the label with the highest count in the
 *     output info.</li>
 *     <li>{@code UNIFORM} - samples a label uniformly at random from the output domain.</li>
 *     <li>{@code STRATIFIED} - samples a label in proportion to its observed count.</li>
 * </ul>
 */
public class DummyClassifierModel extends Model<Label> {
    private static final long serialVersionUID = 1;

    /** Seed recorded by the deterministic model types, which never draw random numbers. */
    private static final long UNUSED_SEED = 12345L;

    /** The baseline strategy this model implements. */
    private final DummyClassifierTrainer.DummyType dummyType;
    /** Label emitted by CONSTANT / MOST_FREQUENT models; UNKNOWN_LABEL for the sampling types. */
    private final Label constantLabel;
    /** CDF over output ids used by the sampling types; null for the deterministic types. */
    private final double[] cdf;
    /** Random source used by the sampling types; null for the deterministic types. */
    private final Random rng;
    /** Seed used to build {@link #rng}; copies of the model are re-seeded from this value. */
    private final long seed;

    /**
     * Builds a MOST_FREQUENT model, which always predicts the label with the highest count.
     * <p>
     * Originally package-private (the decompiler widened it to public).
     *
     * @param modelProvenance Provenance of the model.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain; must be an {@link ImmutableLabelInfo}.
     */
    public DummyClassifierModel(ModelProvenance modelProvenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo) {
        super("dummy-MOST_FREQUENT-classifier", modelProvenance, featureMap, outputInfo, false);
        this.dummyType = DummyClassifierTrainer.DummyType.MOST_FREQUENT;
        this.constantLabel = findMostFrequentLabel(outputInfo);
        this.cdf = null;
        this.seed = UNUSED_SEED;
        this.rng = null;
    }

    /**
     * Builds a sampling model.
     * <p>
     * {@code UNIFORM} selects a uniform CDF over the domain. Any other type selects the
     * stratified (count-proportional) CDF. Originally package-private (the decompiler widened
     * it to public).
     *
     * @param modelProvenance Provenance of the model.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain; must be an {@link ImmutableLabelInfo} for the
     *                   stratified case.
     * @param dummyType The sampling strategy; expected to be UNIFORM or STRATIFIED.
     * @param seed The seed for the model's random number generator.
     */
    public DummyClassifierModel(ModelProvenance modelProvenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, DummyClassifierTrainer.DummyType dummyType, long seed) {
        super("dummy-" + dummyType + "-classifier", modelProvenance, featureMap, outputInfo, false);
        this.dummyType = dummyType;
        this.constantLabel = LabelFactory.UNKNOWN_LABEL;
        this.cdf = dummyType == DummyClassifierTrainer.DummyType.UNIFORM ? generateUniformCDF(outputInfo) : generateStratifiedCDF(outputInfo);
        this.seed = seed;
        this.rng = new Random(seed);
    }

    /**
     * Builds a CONSTANT model, which always predicts {@code label}.
     * <p>
     * Originally package-private (the decompiler widened it to public).
     *
     * @param modelProvenance Provenance of the model.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @param label The label to predict for every example.
     */
    public DummyClassifierModel(ModelProvenance modelProvenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, Label label) {
        super("dummy-CONSTANT-classifier", modelProvenance, featureMap, outputInfo, false);
        this.dummyType = DummyClassifierTrainer.DummyType.CONSTANT;
        this.constantLabel = label;
        this.cdf = null;
        this.seed = UNUSED_SEED;
        this.rng = null;
    }

    /**
     * Predicts a label for {@code example} without inspecting its features.
     * <p>
     * The second {@link Prediction} argument is 0 in every case. This is presumably the
     * number of features used, since none are read.
     *
     * @param example The example (only attached to the prediction).
     * @return The constant label, or a label sampled from the model's CDF.
     * @throws IllegalStateException If the dummy type is not recognised.
     */
    public Prediction<Label> predict(Example<Label> example) {
        switch (this.dummyType) {
            case CONSTANT:
            case MOST_FREQUENT:
                return new Prediction<>(this.constantLabel, 0, example);
            case UNIFORM:
            case STRATIFIED:
                return new Prediction<>(sampleLabel(this.cdf, this.outputIDInfo, this.rng), 0, example);
            default:
                throw new IllegalStateException("Unknown dummyType " + this.dummyType);
        }
    }

    /**
     * Reports a single bias "feature" with weight 1.0 under the {@code "ALL_OUTPUTS"} key.
     *
     * @param n Number of features requested; 0 yields an empty map.
     * @return A mutable map, empty when {@code n == 0}.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        Map<String, List<Pair<String, Double>>> topFeatures = new HashMap<>();
        if (n != 0) {
            topFeatures.put("ALL_OUTPUTS", Collections.singletonList(new Pair<>("BIAS", 1.0d)));
        }
        return topFeatures;
    }

    /**
     * Builds an excuse from the prediction for {@code example} and the single bias feature.
     *
     * @param example The example to explain.
     * @return An excuse; always present.
     */
    public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
        return Optional.of(new Excuse<>(example, predict(example), getTopFeatures(1)));
    }

    /**
     * Copies this model with a new provenance.
     * <p>
     * This is the decompiled {@code copy} override. JADX merged it with its bridge method and
     * renamed it. {@code newName} is unused because each constructor derives the model name
     * from the dummy type. Sampling models are re-seeded from the original seed, so a copy
     * starts from the same random sequence as a freshly built model.
     *
     * @param newName Ignored.
     * @param newProvenance Provenance for the copy.
     * @return A new model of the same dummy type.
     * @throws IllegalStateException If the dummy type is not recognised.
     */
    public DummyClassifierModel m9copy(String newName, ModelProvenance newProvenance) {
        switch (this.dummyType) {
            case CONSTANT:
                return new DummyClassifierModel(newProvenance, this.featureIDMap, this.outputIDInfo, this.constantLabel.m3copy());
            case MOST_FREQUENT:
                return new DummyClassifierModel(newProvenance, this.featureIDMap, this.outputIDInfo);
            case UNIFORM:
            case STRATIFIED:
                return new DummyClassifierModel(newProvenance, this.featureIDMap, this.outputIDInfo, this.dummyType, this.seed);
            default:
                throw new IllegalStateException("Unknown dummyType " + this.dummyType);
        }
    }

    /**
     * Draws an output id from {@code cdf} and maps it back to its label.
     */
    private static Label sampleLabel(double[] cdf, ImmutableOutputInfo<Label> outputInfo, Random random) {
        return outputInfo.getOutput(Util.sampleFromCDF(cdf, random));
    }

    /**
     * Returns the label with the strictly highest count. On ties, the first one in iteration
     * order wins.
     * <p>
     * Returns null if the output info contains no labels.
     *
     * @param outputInfo The output domain; must be an {@link ImmutableLabelInfo}.
     */
    private static Label findMostFrequentLabel(ImmutableOutputInfo<Label> outputInfo) {
        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
        Label mostFrequent = null;
        long bestCount = -1;
        for (Pair<Integer, Label> entry : labelInfo) {
            long count = labelInfo.getLabelCount(entry.getA());
            if (count > bestCount) {
                bestCount = count;
                mostFrequent = entry.getB();
            }
        }
        return mostFrequent;
    }

    /**
     * Builds a CDF which assigns probability {@code 1/size} to every output id.
     */
    private static double[] generateUniformCDF(ImmutableOutputInfo<Label> outputInfo) {
        int size = outputInfo.getDomain().size();
        return Util.generateCDF(Util.generateUniformVector(size, 1.0d / size));
    }

    /**
     * Builds a CDF where each output id's probability is its label count divided by the
     * total number of observations.
     *
     * @param outputInfo The output domain; must be an {@link ImmutableLabelInfo}.
     */
    private static double[] generateStratifiedCDF(ImmutableOutputInfo<Label> outputInfo) {
        ImmutableLabelInfo labelInfo = (ImmutableLabelInfo) outputInfo;
        int numLabels = labelInfo.getDomain().size();
        long totalObservations = labelInfo.getTotalObservations();
        double[] pmf = new double[numLabels];
        for (Pair<Integer, Label> entry : labelInfo) {
            int id = entry.getA();
            // Both operands are longs, so without the cast integer division would truncate
            // every fraction below 1 to zero and the resulting CDF would be all zeros.
            pmf[id] = ((double) labelInfo.getLabelCount(id)) / totalObservations;
        }
        return Util.generateCDF(pmf);
    }
}
