package net.recommenders.rival.evaluation.metric.ranking;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.recommenders.rival.core.DataModel;
import net.recommenders.rival.evaluation.Pair;
import net.recommenders.rival.evaluation.metric.EvaluationMetric;

/* loaded from: input_file:net/recommenders/rival/evaluation/metric/ranking/PopularityStratifiedRecall.class */
/**
 * Popularity-stratified recall over ranked recommendation lists.
 * <p>
 * Each item that counts as relevant contributes a weight derived from its observed count
 * {@code n} (looked up in the map given to the constructor). An item counts as relevant when
 * {@code computeBinaryPrecision(relevance) > 0}. The weight is
 * {@code n^(-gamma / (gamma + 1))}. Items missing from that map contribute 0.
 * <p>
 * For a user, the value at cutoff {@code k} is the weighted relevant mass found in the first
 * {@code k} positions, divided by the weighted relevant mass of that user's whole list. The
 * lists are the ones returned by {@code processDataAsRankedTestRelevance()}: presumably items
 * ranked by predicted score, paired with their test relevance. Confirm against
 * {@code AbstractRankingMetric}.
 *
 * @param <U> type of the users
 * @param <I> type of the items
 */
public class PopularityStratifiedRecall<U, I> extends AbstractRankingMetric<U, I> implements EvaluationMetric<U> {

    /** Weighted relevant mass accumulated up to each cutoff, keyed by cutoff and then by user. */
    private Map<Integer, Map<U, Double>> userRecallAtCutoff;
    /** Weighted relevant mass of each user's complete list; users whose total was NaN are absent. */
    private Map<U, Double> userTotalRecall;
    /** Per-item observed count used as the base of the stratification weight. */
    private final Map<I, Integer> observedItemRelevance;
    /** Stratification parameter; the weight exponent is {@code -gamma / (gamma + 1)}. */
    private final double gamma;

    /**
     * Creates the metric with a relevance threshold of 1.0 and no cutoffs.
     *
     * @param predictions first data model forwarded to the superclass (presumably the predicted
     *        scores; confirm against AbstractRankingMetric)
     * @param test second data model forwarded to the superclass (presumably the test data)
     * @param gamma stratification parameter
     * @param observedItemRelevance per-item observed counts used to compute the weights
     */
    public PopularityStratifiedRecall(DataModel<U, I> predictions, DataModel<U, I> test, double gamma, Map<I, Integer> observedItemRelevance) {
        this(predictions, test, 1.0d, gamma, observedItemRelevance);
    }

    /**
     * Creates the metric with no cutoffs.
     *
     * @param predictions first data model forwarded to the superclass (presumably the predictions)
     * @param test second data model forwarded to the superclass (presumably the test data)
     * @param relThreshold value forwarded to the superclass; reported by
     *        {@code getRelevanceThreshold()} in {@link #toString()}
     * @param gamma stratification parameter
     * @param observedItemRelevance per-item observed counts used to compute the weights
     */
    public PopularityStratifiedRecall(DataModel<U, I> predictions, DataModel<U, I> test, double relThreshold, double gamma, Map<I, Integer> observedItemRelevance) {
        this(predictions, test, relThreshold, new int[0], gamma, observedItemRelevance);
    }

    /**
     * Creates the metric.
     *
     * @param predictions first data model forwarded to the superclass (presumably the predictions)
     * @param test second data model forwarded to the superclass (presumably the test data)
     * @param relThreshold value forwarded to the superclass; reported by
     *        {@code getRelevanceThreshold()} in {@link #toString()}
     * @param ats cutoffs forwarded to the superclass and read back through {@code getCutoffs()}
     * @param gamma stratification parameter
     * @param observedItemRelevance per-item observed counts used to compute the weights
     */
    public PopularityStratifiedRecall(DataModel<U, I> predictions, DataModel<U, I> test, double relThreshold, int[] ats, double gamma, Map<I, Integer> observedItemRelevance) {
        super(predictions, test, relThreshold, ats);
        this.gamma = gamma;
        this.observedItemRelevance = observedItemRelevance;
    }

    /**
     * Computes per-cutoff and per-user values; does nothing while {@code getValue()} is not NaN.
     * <p>
     * Each user's per-user value is the whole list's mass divided by itself. It is therefore
     * 1.0 for a positive finite total, and NaN otherwise (0/0 or infinity/infinity).
     * <p>
     * The overall value is the sum of all finite-or-infinite user totals divided by that same
     * sum. It assumes {@code iniCompute()} resets the value to 0.
     * NOTE(review): this makes the overall value 1.0 whenever any user has relevant weight, and
     * NaN when none does. A NaN value causes the next call to recompute. The informative numbers
     * are the per-cutoff values from {@link #getValueAt(int)}.
     */
    @Override
    public void compute() {
        if (!Double.isNaN(getValue())) {
            return;
        }
        iniCompute();
        Map<U, List<Pair<I, Double>>> rankedTestRelevance = processDataAsRankedTestRelevance();
        userRecallAtCutoff = new HashMap<Integer, Map<U, Double>>();
        userTotalRecall = new HashMap<U, Double>();

        double sumTotalRecall = 0.0d;
        for (Map.Entry<U, List<Pair<I, Double>>> entry : rankedTestRelevance.entrySet()) {
            U user = entry.getKey();
            double userRecall = 0.0d;
            int rank = 0;
            for (Pair<I, Double> pair : entry.getValue()) {
                I item = pair.getFirst();
                rank++;
                if (computeBinaryPrecision(pair.getSecond().doubleValue()) > 0.0d) {
                    userRecall += getPopularityStratificationWeight(item);
                }
                // Snapshot the mass accumulated so far for every cutoff equal to this rank.
                for (int at : getCutoffs()) {
                    if (rank == at) {
                        putRecallAtCutoff(at, user, userRecall);
                    }
                }
            }
            // Cutoffs at or beyond the list length receive the mass of the whole list.
            for (int at : getCutoffs()) {
                if (rank <= at) {
                    putRecallAtCutoff(at, user, userRecall);
                }
            }
            if (!Double.isNaN(userRecall)) {
                setValue(getValue() + userRecall);
                // Full list relative to its own total: 1.0, or NaN for a zero/infinite total.
                getMetricPerUser().put(user, Double.valueOf(userRecall / userRecall));
                userTotalRecall.put(user, Double.valueOf(userRecall));
                sumTotalRecall += userRecall;
            }
        }
        setValue(getValue() / sumTotalRecall);
    }

    /**
     * Stores {@code recall} for {@code user} at cutoff {@code at}, creating that cutoff's map on
     * first use.
     */
    private void putRecallAtCutoff(int at, U user, double recall) {
        Integer cutoff = Integer.valueOf(at);
        Map<U, Double> byUser = userRecallAtCutoff.get(cutoff);
        if (byUser == null) {
            byUser = new HashMap<U, Double>();
            userRecallAtCutoff.put(cutoff, byUser);
        }
        byUser.put(user, Double.valueOf(recall));
    }

    /**
     * Returns {@code count^(-gamma / (gamma + 1))} for the item's observed count.
     * Returns 0 when the item has no (or a null) entry in the observed-count map.
     * Note: per Math.pow, a count of 0 with a negative exponent yields +Infinity.
     */
    private double getPopularityStratificationWeight(I item) {
        Integer count = observedItemRelevance.get(item);
        if (count == null) {
            return 0.0d;
        }
        return Math.pow(count.doubleValue(), -gamma / (gamma + 1.0d));
    }

    /**
     * Returns the average of the per-user values at cutoff {@code at}, weighted by each user's
     * total mass. Users with a NaN value or without a stored total are skipped.
     *
     * @param at the cutoff
     * @return NaN if the cutoff was not computed (including before {@link #compute()} has run);
     *         0.0 if no user carries weight; otherwise the weighted average
     */
    @Override
    public double getValueAt(int at) {
        if (userRecallAtCutoff == null || !userRecallAtCutoff.containsKey(Integer.valueOf(at))) {
            return Double.NaN;
        }
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (U user : userRecallAtCutoff.get(Integer.valueOf(at)).keySet()) {
            Double total = userTotalRecall.get(user);
            if (total == null) {
                // The user's total was NaN during compute(), so it has no usable weight.
                continue;
            }
            double value = getValueAt(user, at);
            if (!Double.isNaN(value)) {
                weightedSum += value * total.doubleValue();
                totalWeight += total.doubleValue();
            }
        }
        return totalWeight == 0.0d ? 0.0d : weightedSum / totalWeight;
    }

    /**
     * Returns the user's mass at cutoff {@code at} divided by the user's total mass.
     *
     * @param user the user
     * @param at the cutoff
     * @return NaN when there is no entry for this user/cutoff, when the user's total was not
     *         stored, or before {@link #compute()} has run; a zero total also yields NaN (0/0)
     */
    @Override
    public double getValueAt(U user, int at) {
        if (userRecallAtCutoff == null) {
            return Double.NaN;
        }
        Map<U, Double> byUser = userRecallAtCutoff.get(Integer.valueOf(at));
        if (byUser == null) {
            return Double.NaN;
        }
        Double atCutoff = byUser.get(user);
        Double total = userTotalRecall.get(user);
        if (atCutoff == null || total == null) {
            return Double.NaN;
        }
        return atCutoff.doubleValue() / total.doubleValue();
    }

    /**
     * Returns the metric name, including gamma and the relevance threshold.
     */
    @Override
    public String toString() {
        return "PopularityStratifiedRecall_" + this.gamma + "_" + getRelevanceThreshold();
    }
}
