package org.tribuo.common.xgboost;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;

/* loaded from: input_file:org/tribuo/common/xgboost/XGBoostFeatureImportance.class */
/**
 * Feature importance scores for an XGBoost-backed Tribuo {@link Model}.
 * <p>
 * Scores are obtained from {@link Booster#getScore(String, String)} for one of the importance
 * types {@code gain}, {@code cover}, {@code weight}, {@code total_gain} or {@code total_cover}.
 * XGBoost's internal feature identifiers are translated back to Tribuo feature names using the
 * model's {@link ImmutableFeatureMap}. Every result is ordered from the highest score to the lowest.
 */
public class XGBoostFeatureImportance {

    // Importance types passed to Booster.getScore; each one is queried separately.
    private static final String GAIN = "gain";
    private static final String COVER = "cover";
    private static final String WEIGHT = "weight";
    private static final String TOTAL_GAIN = "total_gain";
    private static final String TOTAL_COVER = "total_cover";

    private final Booster booster;
    private final ImmutableFeatureMap featureMap;
    private final Model<?> model;

    /**
     * All five importance scores for a single feature.
     */
    public static class XGBoostFeatureImportanceInstance {
        private final String featureName;
        private final double gain;
        private final double cover;
        private final double weight;
        private final double totalGain;
        private final double totalCover;

        /**
         * Creates an importance record for one feature.
         *
         * @param featureName The Tribuo feature name.
         * @param gain The {@code gain} score.
         * @param cover The {@code cover} score.
         * @param weight The {@code weight} score.
         * @param totalGain The {@code total_gain} score.
         * @param totalCover The {@code total_cover} score.
         */
        XGBoostFeatureImportanceInstance(String featureName, double gain, double cover, double weight,
                                         double totalGain, double totalCover) {
            this.featureName = featureName;
            this.gain = gain;
            this.cover = cover;
            this.weight = weight;
            this.totalGain = totalGain;
            this.totalCover = totalCover;
        }

        /** @return The Tribuo feature name. */
        public String getFeatureName() {
            return featureName;
        }

        /** @return The {@code gain} score. */
        public double getGain() {
            return gain;
        }

        /** @return The {@code cover} score. */
        public double getCover() {
            return cover;
        }

        /** @return The {@code weight} score. */
        public double getWeight() {
            return weight;
        }

        /** @return The {@code total_gain} score. */
        public double getTotalGain() {
            return totalGain;
        }

        /** @return The {@code total_cover} score. */
        public double getTotalCover() {
            return totalCover;
        }

        @Override
        public String toString() {
            return String.format("XGBoostFeatureImportanceRecord(feature=%s, gain=%.2f, cover=%.2f, weight=%.2f, totalGain=%.2f, totalCover=%.2f)",
                    featureName, gain, cover, weight, totalGain, totalCover);
        }
    }

    /**
     * Creates a feature importance view over the supplied booster.
     * <p>
     * NOTE(review): the decompiler reports that this constructor was originally package-private.
     * It is kept public so existing callers continue to compile.
     *
     * @param booster The trained booster to query for scores.
     * @param model The Tribuo model wrapping {@code booster}; supplies the feature ID map.
     */
    public XGBoostFeatureImportance(Booster booster, Model<?> model) {
        this.booster = booster;
        this.model = model;
        this.featureMap = model.getFeatureIDMap();
    }

    /**
     * Maps an XGBoost feature identifier to the Tribuo feature name.
     * <p>
     * Assumes XGBoost's default {@code "f<index>"} naming, which applies because no feature-map
     * file is passed to {@code getScore}. The leading character is dropped and the remainder
     * is parsed as the Tribuo feature id.
     */
    private String translateFeatureId(String xgbFeatureName) {
        int featureId = Integer.parseInt(xgbFeatureName.substring(1));
        return featureMap.get(featureId).getName();
    }

    /**
     * Queries the booster for one importance type and sorts the entries by descending score.
     *
     * @param importanceType One of the XGBoost importance type names.
     * @return A stream of (XGBoost feature id, score) pairs, highest score first.
     * @throws IllegalStateException If XGBoost fails to compute the scores.
     */
    private Stream<Map.Entry<String, Double>> getImportanceStream(String importanceType) {
        Map<String, Double> scores;
        try {
            scores = booster.getScore("", importanceType);
        } catch (XGBoostError e) {
            throw new IllegalStateException("Error generating feature importance for " + importanceType, e);
        }
        // The sort is stable, so tied scores keep the booster map's iteration order.
        return scores.entrySet().stream()
                .sorted(Map.Entry.<String, Double>comparingByValue().reversed());
    }

    /**
     * Collects a sorted importance stream into an insertion-ordered map keyed by Tribuo feature name.
     * <p>
     * If two ids translate to the same name, the first entry (the higher score) is kept.
     */
    private LinkedHashMap<String, Double> coalesceImportanceStream(Stream<Map.Entry<String, Double>> importances) {
        return importances.collect(Collectors.toMap(
                entry -> translateFeatureId(entry.getKey()),
                Map.Entry::getValue,
                (first, second) -> first,
                LinkedHashMap::new));
    }

    /**
     * Returns the {@code gain} score of every feature, sorted by descending score.
     *
     * @return Feature name to score, in descending score order.
     */
    public LinkedHashMap<String, Double> getGain() {
        return coalesceImportanceStream(getImportanceStream(GAIN));
    }

    /**
     * Returns the {@code gain} score of the top {@code i} features, sorted by descending score.
     *
     * @param i The maximum number of features to return.
     * @return Feature name to score, in descending score order.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public LinkedHashMap<String, Double> getGain(int i) {
        return coalesceImportanceStream(getImportanceStream(GAIN).limit(i));
    }

    /**
     * Returns the {@code cover} score of every feature, sorted by descending score.
     *
     * @return Feature name to score, in descending score order.
     */
    public LinkedHashMap<String, Double> getCover() {
        return coalesceImportanceStream(getImportanceStream(COVER));
    }

    /**
     * Returns the {@code cover} score of the top {@code i} features, sorted by descending score.
     *
     * @param i The maximum number of features to return.
     * @return Feature name to score, in descending score order.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public LinkedHashMap<String, Double> getCover(int i) {
        return coalesceImportanceStream(getImportanceStream(COVER).limit(i));
    }

    /**
     * Returns the {@code weight} score of every feature, sorted by descending score.
     *
     * @return Feature name to score, in descending score order.
     */
    public LinkedHashMap<String, Double> getWeight() {
        return coalesceImportanceStream(getImportanceStream(WEIGHT));
    }

    /**
     * Returns the {@code weight} score of the top {@code i} features, sorted by descending score.
     *
     * @param i The maximum number of features to return.
     * @return Feature name to score, in descending score order.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public LinkedHashMap<String, Double> getWeight(int i) {
        return coalesceImportanceStream(getImportanceStream(WEIGHT).limit(i));
    }

    /**
     * Returns the {@code total_gain} score of every feature, sorted by descending score.
     *
     * @return Feature name to score, in descending score order.
     */
    public LinkedHashMap<String, Double> getTotalGain() {
        return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN));
    }

    /**
     * Returns the {@code total_gain} score of the top {@code i} features, sorted by descending score.
     *
     * @param i The maximum number of features to return.
     * @return Feature name to score, in descending score order.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public LinkedHashMap<String, Double> getTotalGain(int i) {
        return coalesceImportanceStream(getImportanceStream(TOTAL_GAIN).limit(i));
    }

    /**
     * Returns the {@code total_cover} score of every feature, sorted by descending score.
     *
     * @return Feature name to score, in descending score order.
     */
    public LinkedHashMap<String, Double> getTotalCover() {
        return coalesceImportanceStream(getImportanceStream(TOTAL_COVER));
    }

    /**
     * Returns the {@code total_cover} score of the top {@code i} features, sorted by descending score.
     *
     * @param i The maximum number of features to return.
     * @return Feature name to score, in descending score order.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public LinkedHashMap<String, Double> getTotalCover(int i) {
        return coalesceImportanceStream(getImportanceStream(TOTAL_COVER).limit(i));
    }

    /**
     * Returns every importance score for every feature, ordered by descending {@code gain}.
     *
     * @return One record per feature.
     */
    public List<XGBoostFeatureImportanceInstance> getImportances() {
        return collectImportances(Long.MAX_VALUE);
    }

    /**
     * Returns every importance score for the top {@code i} features by {@code gain}.
     *
     * @param i The maximum number of features to return.
     * @return At most {@code i} records, ordered by descending {@code gain}.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public List<XGBoostFeatureImportanceInstance> getImportances(int i) {
        return collectImportances(i);
    }

    /**
     * Shared implementation of the two {@code getImportances} overloads.
     * <p>
     * All five importance types are fetched, and the gain ranking drives both the order and the
     * truncation. A feature missing from a secondary map is given a score of 0.0 for that type,
     * instead of failing with an unboxing NullPointerException.
     *
     * @param limit The maximum number of records; negative values are rejected by {@link Stream#limit}.
     * @return The importance records, ordered by descending gain.
     */
    private List<XGBoostFeatureImportanceInstance> collectImportances(long limit) {
        LinkedHashMap<String, Double> gain = coalesceImportanceStream(getImportanceStream(GAIN));
        LinkedHashMap<String, Double> cover = coalesceImportanceStream(getImportanceStream(COVER));
        LinkedHashMap<String, Double> weight = coalesceImportanceStream(getImportanceStream(WEIGHT));
        LinkedHashMap<String, Double> totalGain = coalesceImportanceStream(getImportanceStream(TOTAL_GAIN));
        LinkedHashMap<String, Double> totalCover = coalesceImportanceStream(getImportanceStream(TOTAL_COVER));
        return gain.entrySet().stream()
                .limit(limit)
                .map(entry -> {
                    String name = entry.getKey();
                    return new XGBoostFeatureImportanceInstance(name,
                            entry.getValue(),
                            cover.getOrDefault(name, 0.0),
                            weight.getOrDefault(name, 0.0),
                            totalGain.getOrDefault(name, 0.0),
                            totalCover.getOrDefault(name, 0.0));
                })
                .collect(Collectors.toList());
    }

    @Override
    public String toString() {
        return String.format("XGBoostFeatureImportance(model=%s)", model.toString());
    }
}
