package ciir.umass.edu.learning.boosting;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:ciir/umass/edu/learning/boosting/RankBoost.class */
/**
 * RankBoost learner: a pairwise boosting ranker whose weak learners ({@link RBWeakRanker}) are
 * single-feature thresholds.
 *
 * <p>{@link #init()} orders each query with {@code getCorrectRanking()} and gives a uniform
 * weight to every document pair (j before k) whose labels satisfy {@code label(j) > label(k)}.
 * Each round of {@link #learn()}:
 * <ol>
 *   <li>picks the feature/threshold that maximizes the summed "potential" of the documents
 *       above the threshold;</li>
 *   <li>assigns it the weight {@code alpha = 0.5 * ln((Z + R) / (Z - R))};</li>
 *   <li>re-weights and normalizes the pairs.</li>
 * </ol>
 * A document's final score is {@code sum_t alpha_t * h_t(x)} (see {@link #eval(DataPoint)}).
 *
 * <p>Training mutates instance fields without synchronization. {@link #nIteration} and
 * {@link #nThreshold} are static settings shared by all instances.
 */
public class RankBoost extends Ranker {
    /** Maximum number of boosting rounds performed by {@link #learn()}. */
    public static int nIteration = 300;

    /**
     * Number of equal-width threshold bins per feature. When {@code <= 0}, every feature value
     * observed in the training data becomes a threshold candidate instead.
     */
    public static int nThreshold = 10;

    /**
     * Pair weights: {@code sweight[q][j][k]} (j &lt; k) weights the constraint "document j should
     * score above document k" in query q. The last row {@code sweight[q][size - 1]} is never
     * allocated because it would hold no pairs.
     */
    protected double[][][] sweight = null;

    /**
     * {@code potential[q][j]}: total weight of pairs in which document j is the upper element,
     * minus total weight of pairs in which it is the lower element. Recomputed every round.
     */
    protected double[][] potential = null;

    /**
     * {@code sortedSamples.get(f).get(q)}: indices of query q's documents ordered by feature
     * {@code features[f]}, as returned by {@code MergeSorter.sort(values, false)}.
     */
    protected List<List<int[]>> sortedSamples = new ArrayList<>();

    /** {@code thresholds[f]}: candidate split points for feature {@code features[f]}. */
    protected double[][] thresholds = null;

    /** {@code tSortedIdx[f]}: indices into {@code thresholds[f]} from {@code MergeSorter.sort(..., false)}. */
    protected int[][] tSortedIdx = null;

    /** Weak rankers of the current model; parallel to {@link #rWeight}. */
    protected List<RBWeakRanker> wRankers = null;

    /** Weight (alpha) of each weak ranker in {@link #wRankers}. */
    protected List<Double> rWeight = null;

    /** Weak rankers of the model with the best validation score seen during training. */
    protected List<RBWeakRanker> bestModelRankers = new ArrayList<>();

    /** Weights matching {@link #bestModelRankers}. */
    protected List<Double> bestModelWeights = new ArrayList<>();

    /** {@code Z_t * r} for the weak ranker chosen in the current round (printed as "Error"). */
    private double R_t = 0.0d;

    /** Sum of the re-weighted pair weights before normalization (1 before the first round). */
    private double Z_t = 1.0d;

    /** Number of within-query pairs (j before k) with {@code label(j) > label(k)}. */
    private int totalCorrectPairs = 0;

    public RankBoost() {
    }

    public RankBoost(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Returns the indices of {@code rankList}'s documents ordered by feature {@code fid}.
     * The order comes from {@code MergeSorter.sort(values, false)}. The incremental sweep in
     * {@link #learnWeakRanker()} presumes this is descending — confirm against MergeSorter.
     */
    private int[] reorder(RankList rankList, int fid) {
        double[] values = new double[rankList.size()];
        for (int j = 0; j < values.length; j++) {
            values[j] = rankList.get(j).getFeatureValue(fid);
        }
        return MergeSorter.sort(values, false);
    }

    /**
     * Recomputes {@link #potential} from the current {@link #sweight}. Each document gets
     * (weight of pairs where it is the upper element) minus (weight of pairs where it is the
     * lower element).
     */
    private void updatePotential() {
        for (int q = 0; q < this.samples.size(); q++) {
            RankList rl = this.samples.get(q);
            int n = rl.size();
            for (int j = 0; j < n; j++) {
                double p = 0.0d;
                for (int k = j + 1; k < n; k++) {
                    p += this.sweight[q][j][k];
                }
                for (int i = 0; i < j; i++) {
                    p -= this.sweight[q][i][j];
                }
                this.potential[q][j] = p;
            }
        }
    }

    /**
     * Selects the (feature, threshold) pair that maximizes r. Here r is the sum of
     * {@link #potential} over all documents whose feature value is strictly greater than the
     * threshold.
     *
     * <p>Thresholds are visited in {@link #tSortedIdx} order. Documents are visited in
     * {@link #sortedSamples} order, and r is accumulated incrementally: {@code last[q]}
     * remembers how far into query q's ordering the previous threshold already counted.
     *
     * @return the best weak ranker, or {@code null} if no candidate beat the initial bound.
     *     As a side effect, sets {@link #R_t} to {@code Z_t * r_best}.
     */
    private RBWeakRanker learnWeakRanker() {
        int bestFid = -1;
        double maxR = -10.0d;
        double bestThreshold = -1.0d;
        for (int f = 0; f < this.features.length; f++) {
            int fid = this.features[f];
            List<int[]> sortedByFeature = this.sortedSamples.get(f);
            int[] last = new int[this.samples.size()];
            for (int q = 0; q < last.length; q++) {
                last[q] = -1;
            }
            double r = 0.0d;
            for (int tIdx : this.tSortedIdx[f]) {
                double t = this.thresholds[f][tIdx];
                for (int q = 0; q < this.samples.size(); q++) {
                    RankList rl = this.samples.get(q);
                    int[] order = sortedByFeature.get(q);
                    for (int l = last[q] + 1; l < rl.size(); l++) {
                        int doc = order[l];
                        if (rl.get(doc).getFeatureValue(fid) <= t) {
                            break;
                        }
                        r += this.potential[q][doc];
                        last[q] = l;
                    }
                }
                if (r > maxR) {
                    maxR = r;
                    bestThreshold = t;
                    bestFid = fid;
                }
            }
        }
        if (bestFid == -1) {
            return null;
        }
        this.R_t = this.Z_t * maxR;
        return new RBWeakRanker(bestFid, bestThreshold);
    }

    /**
     * Prepares training state.
     *
     * <p>Replaces each sample with its {@code getCorrectRanking()}. Builds the initial pair
     * weights, the potential buffers, the threshold candidates, and the per-feature document
     * orderings. Safe to call again on the same instance: per-training state is reset first.
     */
    @Override
    public void init() {
        PRINT("Initializing... ");
        this.wRankers = new ArrayList<>();
        this.rWeight = new ArrayList<>();
        // Reset per-training state. Otherwise a second init() would leave the previous run's
        // per-feature orderings at the front of sortedSamples, and could later restore the
        // previous run's best model.
        this.sortedSamples = new ArrayList<>();
        this.bestModelRankers.clear();
        this.bestModelWeights.clear();
        this.R_t = 0.0d;
        this.Z_t = 1.0d;

        this.totalCorrectPairs = 0;
        for (int q = 0; q < this.samples.size(); q++) {
            RankList ranked = this.samples.get(q).getCorrectRanking();
            this.samples.set(q, ranked);
            this.totalCorrectPairs += countCorrectPairs(ranked);
        }
        initPairWeights();

        this.potential = new double[this.samples.size()][];
        for (int q = 0; q < this.samples.size(); q++) {
            this.potential[q] = new double[this.samples.get(q).size()];
        }

        this.thresholds = (nThreshold <= 0) ? allObservedValues() : equalWidthThresholds(nThreshold);

        this.tSortedIdx = new int[this.features.length][];
        for (int f = 0; f < this.features.length; f++) {
            this.tSortedIdx[f] = MergeSorter.sort(this.thresholds[f], false);
            List<int[]> perQuery = new ArrayList<>(this.samples.size());
            for (RankList rl : this.samples) {
                perQuery.add(reorder(rl, this.features[f]));
            }
            this.sortedSamples.add(perQuery);
        }
        PRINTLN("[Done]");
    }

    /** Counts pairs (j before k) in {@code rl} with a strictly greater label at j. */
    private static int countCorrectPairs(RankList rl) {
        int count = 0;
        for (int j = 0; j < rl.size() - 1; j++) {
            for (int k = j + 1; k < rl.size(); k++) {
                if (rl.get(j).getLabel() > rl.get(k).getLabel()) {
                    count++;
                }
            }
        }
        return count;
    }

    /**
     * Allocates {@link #sweight}. Each pair with {@code label(j) > label(k)} gets
     * {@code 1 / totalCorrectPairs}; all other pairs, including equal labels, get 0.
     */
    private void initPairWeights() {
        this.sweight = new double[this.samples.size()][][];
        for (int q = 0; q < this.samples.size(); q++) {
            RankList rl = this.samples.get(q);
            int n = rl.size();
            this.sweight[q] = new double[n][];
            for (int j = 0; j < n - 1; j++) {
                this.sweight[q][j] = new double[n];
                for (int k = j + 1; k < n; k++) {
                    if (rl.get(j).getLabel() > rl.get(k).getLabel()) {
                        this.sweight[q][j][k] = 1.0d / this.totalCorrectPairs;
                    }
                }
            }
        }
    }

    /** Threshold candidates = every feature value of every training document (nThreshold &lt;= 0). */
    private double[][] allObservedValues() {
        int total = 0;
        for (RankList rl : this.samples) {
            total += rl.size();
        }
        double[][] result = new double[this.features.length][total];
        int pos = 0;
        for (RankList rl : this.samples) {
            for (int j = 0; j < rl.size(); j++) {
                DataPoint p = rl.get(j);
                for (int f = 0; f < this.features.length; f++) {
                    result[f][pos] = p.getFeatureValue(this.features[f]);
                }
                pos++;
            }
        }
        return result;
    }

    /**
     * Builds {@code bins + 1} threshold candidates per feature.
     *
     * <p>The first {@code bins} candidates step down from the observed maximum in equal-width
     * steps. A final candidate well below the observed minimum lets the sweep in
     * {@link #learnWeakRanker()} include every document.
     */
    private double[][] equalWidthThresholds(int bins) {
        int nf = this.features.length;
        double[] fmax = new double[nf];
        double[] fmin = new double[nf];
        for (int f = 0; f < nf; f++) {
            // Infinite sentinels: finite ones (previously +/-1e6) gave a wrong range whenever
            // every value of a feature lay beyond them.
            fmax[f] = Double.NEGATIVE_INFINITY;
            fmin[f] = Double.POSITIVE_INFINITY;
        }
        for (RankList rl : this.samples) {
            for (int j = 0; j < rl.size(); j++) {
                DataPoint p = rl.get(j);
                for (int f = 0; f < nf; f++) {
                    double v = p.getFeatureValue(this.features[f]);
                    if (v > fmax[f]) {
                        fmax[f] = v;
                    }
                    if (v < fmin[f]) {
                        fmin[f] = v;
                    }
                }
            }
        }
        double[][] result = new double[nf][];
        for (int f = 0; f < nf; f++) {
            if (fmax[f] < fmin[f]) {
                // No comparable value seen for this feature: avoid inf - inf = NaN below.
                fmax[f] = 0.0d;
                fmin[f] = 0.0d;
            }
            double step = Math.abs(fmax[f] - fmin[f]) / bins;
            double[] t = new double[bins + 1];
            t[0] = fmax[f];
            for (int b = 1; b < bins; b++) {
                t[b] = t[b - 1] - step;
            }
            t[bins] = fmin[f] - 1.0E8d;
            result[f] = t;
        }
        return result;
    }

    /**
     * Runs up to {@link #nIteration} boosting rounds.
     *
     * <p>Training stops early if no weak ranker can be selected. When validation data is
     * present, the model with the best validation score seen during training is kept.
     */
    @Override
    public void learn() {
        PRINTLN("------------------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("--------------------------------------------------------------------");
        PRINTLN(new int[]{7, 8, 9, 9, 9, 9}, new String[]{"#iter", "Sel. F.", "Threshold", "Error", this.scorer.name() + "-T", this.scorer.name() + "-V"});
        PRINTLN("--------------------------------------------------------------------");
        for (int t = 1; t <= nIteration; t++) {
            updatePotential();
            RBWeakRanker weak = learnWeakRanker();
            if (weak == null) {
                break;
            }
            // alpha_t = 0.5 * ln((Z_t + R_t) / (Z_t - R_t)); diverges when R_t == Z_t.
            double alpha = 0.5d * SimpleMath.ln((this.Z_t + this.R_t) / (this.Z_t - this.R_t));
            this.wRankers.add(weak);
            this.rWeight.add(alpha);
            reweightPairs(weak, alpha);

            PRINT(new int[]{7, 8, 9, 9}, new String[]{t + "", weak.getFid() + "", SimpleMath.round(weak.getThreshold(), 4) + "", SimpleMath.round(this.R_t, 4) + ""});
            PRINT(new int[]{9}, new String[]{SimpleMath.round(this.scorer.score(rank(this.samples)), 4) + ""});
            if (this.validationSamples != null) {
                double score = this.scorer.score(rank(this.validationSamples));
                if (score > this.bestScoreOnValidationData) {
                    this.bestScoreOnValidationData = score;
                    this.bestModelRankers.clear();
                    this.bestModelRankers.addAll(this.wRankers);
                    this.bestModelWeights.clear();
                    this.bestModelWeights.addAll(this.rWeight);
                }
                PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
            }
            PRINTLN("");
            normalizePairWeights();
        }
        if (this.validationSamples != null && this.bestModelRankers.size() > 0) {
            this.wRankers.clear();
            this.rWeight.clear();
            this.wRankers.addAll(this.bestModelRankers);
            this.rWeight.addAll(this.bestModelWeights);
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(rank(this.samples)), 4);
        PRINTLN("--------------------------------------------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(rank(this.validationSamples));
            PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    /**
     * Multiplies every pair weight by {@code exp(alpha * (h(lower) - h(upper)))} and stores the
     * sum of the updated weights in {@link #Z_t}. Weights are left unnormalized; see
     * {@link #normalizePairWeights()}.
     */
    private void reweightPairs(RBWeakRanker weak, double alpha) {
        this.Z_t = 0.0d;
        for (int q = 0; q < this.samples.size(); q++) {
            RankList rl = this.samples.get(q);
            int n = rl.size();
            // Score each document once rather than once per pair.
            double[] h = new double[n];
            for (int j = 0; j < n; j++) {
                h[j] = weak.score(rl.get(j));
            }
            double[][] w = this.sweight[q];
            for (int j = 0; j < n - 1; j++) {
                for (int k = j + 1; k < n; k++) {
                    w[j][k] *= Math.exp(alpha * (h[k] - h[j]));
                    this.Z_t += w[j][k];
                }
            }
        }
    }

    /** Divides every pair weight by {@link #Z_t}, so the weights sum to 1 again. */
    private void normalizePairWeights() {
        for (int q = 0; q < this.samples.size(); q++) {
            int n = this.samples.get(q).size();
            for (int j = 0; j < n - 1; j++) {
                double[] row = this.sweight[q][j];
                for (int k = j + 1; k < n; k++) {
                    row[k] /= this.Z_t;
                }
            }
        }
    }

    /** Returns {@code sum_i rWeight[i] * wRankers[i].score(dataPoint)}. */
    @Override
    public double eval(DataPoint dataPoint) {
        double score = 0.0d;
        for (int i = 0; i < this.wRankers.size(); i++) {
            score += this.rWeight.get(i) * this.wRankers.get(i).score(dataPoint);
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new RankBoost();
    }

    /**
     * Serializes the model as space-separated {@code <weakRanker>:<weight>} entries. This is
     * the single-line format parsed by {@link #loadFromString(String)}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.wRankers.size(); i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(this.wRankers.get(i).toString()).append(':').append(this.rWeight.get(i));
        }
        return sb.toString();
    }

    /** Returns a "##"-prefixed header with the training settings, followed by {@link #toString()}. */
    @Override
    public String model() {
        return "## " + name() + "\n"
                + "## Iteration = " + nIteration + "\n"
                + "## No. of threshold candidates = " + nThreshold + "\n"
                + toString();
    }

    /**
     * Loads a model produced by {@link #model()}.
     *
     * <p>The first non-blank line not starting with "##" is parsed. Anything after its last
     * '#' is dropped, and the rest is read as whitespace-separated
     * {@code fid:threshold:weight} entries. The current model is replaced only if parsing
     * succeeds.
     *
     * @throws RuntimeException (from {@code RankLibError.create}) if the text is malformed
     */
    @Override
    public void loadFromString(String fullText) {
        try (BufferedReader in = new BufferedReader(new StringReader(fullText))) {
            String content = null;
            for (String line = in.readLine(); line != null; line = in.readLine()) {
                line = line.trim();
                if (!line.isEmpty() && !line.startsWith("##")) {
                    content = line;
                    break;
                }
            }
            if (content == null) {
                throw new IllegalArgumentException("model text contains no weak-ranker line");
            }
            int hash = content.lastIndexOf('#');
            if (hash != -1) {
                content = content.substring(0, hash).trim();
            }
            List<RBWeakRanker> rankers = new ArrayList<>();
            List<Double> weights = new ArrayList<>();
            for (String token : content.split("\\s+")) {
                if (token.isEmpty()) {
                    continue;
                }
                String[] parts = token.split(":");
                if (parts.length < 3) {
                    throw new IllegalArgumentException("malformed weak ranker entry: " + token);
                }
                rankers.add(new RBWeakRanker(Integer.parseInt(parts[0]), Double.parseDouble(parts[1])));
                weights.add(Double.parseDouble(parts[2]));
            }
            int[] fids = new int[rankers.size()];
            for (int i = 0; i < fids.length; i++) {
                fids[i] = rankers.get(i).getFid();
            }
            this.wRankers = rankers;
            this.rWeight = weights;
            this.features = fids;
        } catch (Exception e) {
            throw RankLibError.create("Error in RankBoost::load(): ", e);
        }
    }

    @Override
    public void printParameters() {
        PRINTLN("No. of rounds: " + nIteration);
        PRINTLN("No. of threshold candidates: " + nThreshold);
    }

    @Override
    public String name() {
        return "RankBoost";
    }
}
