package de.datexis.retrieval.eval;

import de.datexis.annotator.AnnotatorEvaluation;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Query;
import de.datexis.model.Result;
import de.datexis.retrieval.model.ScoredResult;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/retrieval/eval/RetrievalEvaluation.class */
/**
 * Evaluates ranked retrieval output of {@link Query} objects against their gold-standard results.
 *
 * <p>For each query, predicted results ({@code PRED} source) are ranked 1..n in list order and
 * matched against the gold results ({@code GOLD} source) via {@code Result#matches}. Matching
 * predictions take over the gold result's binary relevance flag and graded relevance. Per-query
 * values of MRR, average precision, P@k, R@k, DCG@k, IDCG@k and nDCG@k (1 &lt;= k &lt;= 10) are
 * summed into the protected accumulator fields. The getters divide those sums by
 * {@code countExamples} (incremented once per evaluated query), i.e. they report macro-averages.
 */
public class RetrievalEvaluation extends AnnotatorEvaluation {

    /** Largest cutoff k for which the @k metrics are accumulated (arrays are indexed 1..MAX_K). */
    private static final int MAX_K = 10;

    protected final Logger log;
    /** Sum of per-query reciprocal ranks of the first relevant prediction. */
    protected double mrrsum = 0.0d;
    /** Sum of per-query average precision values. */
    protected double mapsum = 0.0d;
    // The @k accumulators below are indexed by k; index 0 is unused.
    protected double[] precisionKsum = new double[MAX_K + 1];
    protected double[] recallKsum = new double[MAX_K + 1];
    protected double[] dcgKsum = new double[MAX_K + 1];
    protected double[] idcgKsum = new double[MAX_K + 1];
    protected double[] ndcgKsum = new double[MAX_K + 1];

    /**
     * Creates an evaluation comparing {@code GOLD} against {@code PRED} annotations.
     *
     * @param name name passed on to {@link AnnotatorEvaluation}
     */
    public RetrievalEvaluation(String name) {
        super(name, Annotation.Source.GOLD, Annotation.Source.PRED);
        this.log = LoggerFactory.getLogger(getClass());
    }

    /**
     * Evaluates all queries of the given dataset.
     *
     * @param dataset dataset whose {@code getQueries()} are evaluated
     */
    public void evaluateQueries(Dataset dataset) {
        evaluateQueries(dataset.getQueries());
    }

    /**
     * Evaluates every query and adds its metrics to the running sums, then logs a summary line.
     *
     * <p>Side effect: the predicted results of each query get their rank, relevance flag and graded
     * relevance overwritten (see {@link #resetAndRankPredictions}).
     *
     * @param queries queries carrying both GOLD and PRED results
     */
    public void evaluateQueries(Collection<Query> queries) {
        for (Query query : queries) {
            evaluateQuery(query);
        }
        log.info("{} queries, {} examples MRR={} P@1={} P@3={} P@5={} R@1={} R@3={} MAP={}",
                queries.size(), countExamples, getMRR(), getPrecisionK(1), getPrecisionK(3),
                getPrecisionK(5), getRecallK(1), getRecallK(3), getMAP());
    }

    /** Evaluates a single query and adds its contribution to all accumulators. */
    private void evaluateQuery(Query query) {
        List<Result> gold = query.getResults(Annotation.Source.GOLD, Result.class);
        List<Result> predicted = query.getResults(Annotation.Source.PRED, ScoredResult.class);

        resetAndRankPredictions(predicted);
        double[] idealDcg = accumulateIdealDcg(gold);
        transferGoldJudgements(gold, predicted);
        accumulateReciprocalRank(predicted);
        accumulateRankedMetrics(gold, predicted, idealDcg);
        countExamples++;
    }

    /**
     * Assigns ranks 1..n in list order and clears any previous relevance judgement, so that
     * predictions without a gold match count as non-relevant.
     */
    private static void resetAndRankPredictions(List<Result> predicted) {
        int rank = 0;
        for (Result prediction : predicted) {
            prediction.setRank(++rank);
            prediction.setRelevance(0);
            // Also reset the binary flag; otherwise a stale value would leak into MRR/MAP/P@k.
            prediction.setRelevant(false);
        }
    }

    /**
     * Computes cumulative ideal DCG for k = 1..MAX_K and adds it to {@link #idcgKsum}.
     *
     * @return array indexed by k holding this query's IDCG@k (index 0 unused)
     */
    private double[] accumulateIdealDcg(List<Result> gold) {
        // The ideal ranking orders gold results by descending graded relevance, independent of
        // the order in which they happen to be stored.
        int[] ascending = gold.stream()
                .mapToInt(result -> result.getRelevance().intValue())
                .sorted()
                .toArray();
        double[] ideal = new double[MAX_K + 1];
        double cumulative = 0.0d;
        for (int k = 1; k <= MAX_K; k++) {
            int index = ascending.length - k; // walk from the highest relevance downwards
            if (index >= 0) {
                cumulative += getDCGlog(ascending[index], k);
            }
            ideal[k] = cumulative;
            idcgKsum[k] += cumulative;
        }
        return ideal;
    }

    /**
     * Copies relevance judgements from each gold result to every prediction matching it.
     * Later gold results overwrite earlier ones on multiple matches.
     */
    private static void transferGoldJudgements(List<Result> gold, List<Result> predicted) {
        for (Result expected : gold) {
            for (Result prediction : predicted) {
                if (prediction.matches(expected)) {
                    prediction.setRelevant(expected.isRelevant());
                    prediction.setRelevance(expected.getRelevance());
                }
            }
        }
    }

    /** Adds 1/rank of the first relevant prediction (nothing if none is relevant) to MRR. */
    private void accumulateReciprocalRank(List<Result> predicted) {
        for (Result prediction : predicted) {
            if (prediction.isRelevant()) {
                mrrsum += 1.0d / prediction.getRank().intValue();
                return;
            }
        }
    }

    /**
     * Walks the ranked predictions and accumulates P@k, R@k, DCG@k, nDCG@k and average precision.
     */
    private void accumulateRankedMetrics(List<Result> gold, List<Result> predicted, double[] idealDcg) {
        long relevantCount = gold.stream().filter(Result::isRelevant).count();
        int hits = 0;
        double precisionAtHits = 0.0d; // sum of precision at each relevant rank, for AP
        double dcg = 0.0d;
        int k = 0;
        for (Result prediction : predicted) {
            k++;
            assert k == prediction.getRank().intValue() : "prediction rank out of sync";
            boolean relevant = prediction.isRelevant();
            if (relevant) {
                hits++;
            }
            if (k <= MAX_K) {
                dcg += getDCGlog(prediction.getRelevance().intValue(), k);
                accumulateAtK(k, hits, relevantCount, dcg, idealDcg);
            }
            if (relevant) {
                precisionAtHits += div(hits, k);
            }
            // Every relevant gold result has been found; stop scanning the ranking.
            if (hits >= relevantCount) {
                break;
            }
        }
        // Cutoffs beyond the scanned prefix keep the final hit count and cumulative DCG.
        while (k < MAX_K) {
            k++;
            accumulateAtK(k, hits, relevantCount, dcg, idealDcg);
        }
        mapsum += div(precisionAtHits, relevantCount);
    }

    /** Adds this query's P@k, R@k, DCG@k and nDCG@k to the accumulators for cutoff {@code k}. */
    private void accumulateAtK(int k, int hits, long relevantCount, double dcg, double[] idealDcg) {
        precisionKsum[k] += div(hits, k);
        recallKsum[k] += div(hits, relevantCount);
        dcgKsum[k] += dcg;
        // Without positively-graded gold results IDCG is 0; count nDCG as 0 instead of NaN,
        // which would otherwise corrupt the running sum for all queries.
        ndcgKsum[k] += idealDcg[k] > 0.0d ? dcg / idealDcg[k] : 0.0d;
    }

    /**
     * Returns the DCG gain of one result: {@code (2^relevance - 1) / log2(rank + 1)}.
     *
     * @param i graded relevance
     * @param i2 1-based rank
     */
    protected double getDCGlog(int i, int i2) {
        return (MathUtils.pow(2.0d, i) - 1.0d) / MathUtils.log2(i2 + 1);
    }

    /** Returns the accumulated reciprocal-rank sum divided by {@code countExamples}. */
    protected double getMRR() {
        return this.mrrsum / this.countExamples;
    }

    /** Returns the accumulated average-precision sum divided by {@code countExamples}. */
    public double getMAP() {
        return this.mapsum / this.countExamples;
    }

    /**
     * Returns macro-averaged precision at cutoff {@code k}.
     *
     * @throws IllegalArgumentException unless {@code 0 < k <= 10}
     */
    public double getPrecisionK(int k) {
        checkK(k);
        return this.precisionKsum[k] / this.countExamples;
    }

    /**
     * Returns macro-averaged recall at cutoff {@code k}.
     *
     * @throws IllegalArgumentException unless {@code 0 < k <= 10}
     */
    public double getRecallK(int k) {
        checkK(k);
        return this.recallKsum[k] / this.countExamples;
    }

    /**
     * Returns macro-averaged DCG at cutoff {@code k}.
     *
     * @throws IllegalArgumentException unless {@code 0 < k <= 10}
     */
    public double getDCG(int k) {
        checkK(k);
        return this.dcgKsum[k] / this.countExamples;
    }

    /**
     * Returns macro-averaged ideal DCG at cutoff {@code k}.
     *
     * @throws IllegalArgumentException unless {@code 0 < k <= 10}
     */
    protected double getIDCG(int k) {
        checkK(k);
        return this.idcgKsum[k] / this.countExamples;
    }

    /**
     * Returns macro-averaged nDCG at cutoff {@code k}.
     *
     * @throws IllegalArgumentException unless {@code 0 < k <= 10}
     */
    public double getNDCG(int k) {
        checkK(k);
        return this.ndcgKsum[k] / this.countExamples;
    }

    /** Validates a cutoff against the tracked range 1..MAX_K. */
    private static void checkK(int k) {
        if (k <= 0 || k > MAX_K) {
            throw new IllegalArgumentException("illegal argument 0 < k <= " + MAX_K);
        }
    }

    /** Returns MAP as the headline score of this evaluation. */
    public double getScore() {
        return getMAP();
    }

    /**
     * Not supported: retrieval evaluation needs queries, not plain documents.
     *
     * @throws UnsupportedOperationException always
     */
    public void calculateScores(Collection<Document> docs) {
        throw new UnsupportedOperationException("RetrievalEvaluation requires a Dataset with Queries");
    }

    /**
     * Builds a tab-separated summary table, prints it to {@code System.out} and returns it.
     */
    public String printEvaluationStats() {
        StringBuilder sb = new StringBuilder("\n");
        sb.append("RETRIEVAL EVALUATION [macro-avg]\n");
        sb.append("|queries|\t P@1\t P@5\t P@10\t R@1\t R@5\t R@10\tnDCG@10\t MRR\t MAP\t");
        sb.append("\n");
        sb.append(fInt(countExamples())).append("\t");
        sb.append(fDbl(getPrecisionK(1))).append("\t");
        sb.append(fDbl(getPrecisionK(5))).append("\t");
        sb.append(fDbl(getPrecisionK(10))).append("\t");
        sb.append(fDbl(getRecallK(1))).append("\t");
        sb.append(fDbl(getRecallK(5))).append("\t");
        sb.append(fDbl(getRecallK(10))).append("\t");
        sb.append(fDbl(getNDCG(10))).append("\t");
        sb.append(fDbl(getMRR())).append("\t");
        sb.append(fDbl(getMAP())).append("\t");
        sb.append("\n");
        String stats = sb.toString();
        System.out.println(stats);
        return stats;
    }
}
