package de.viadee.xai.anchor.algorithm.exploration;

import de.viadee.xai.anchor.algorithm.AnchorCandidate;
import de.viadee.xai.anchor.algorithm.execution.SamplingService;
import de.viadee.xai.anchor.algorithm.util.KLBernoulliUtils;
import de.viadee.xai.anchor.algorithm.util.MathUtils;
import de.viadee.xai.anchor.algorithm.util.ParameterValidation;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Best-anchor identification based on the KL-LUCB scheme.
 *
 * <p>In every round the candidates are split into the current top-{@code n} set (highest precision)
 * and the rest. Two candidates are then sampled with {@link #batchSize} evaluations each:
 * <ul>
 *   <li>the non-top candidate with the largest KL upper confidence bound;</li>
 *   <li>the top candidate with the smallest KL lower confidence bound.</li>
 * </ul>
 * Sampling stops once the former's upper bound exceeds the latter's lower bound by no more than
 * {@code epsilon}.
 */
public class KL_LUCB implements BestAnchorIdentification {
    private static final long serialVersionUID = -1417031236085364837L;
    private static final int DEFAULT_KL_LUCB_BATCH_SIZE = 100;

    /** Number of evaluations registered per sampled candidate and round. Always positive. */
    private final int batchSize;

    /**
     * Creates an instance that uses the default batch size of {@value #DEFAULT_KL_LUCB_BATCH_SIZE}.
     */
    public KL_LUCB() {
        this(DEFAULT_KL_LUCB_BATCH_SIZE);
    }

    /**
     * Creates an instance with a custom batch size.
     *
     * @param batchSize number of evaluations registered for each sampled candidate per round
     * @throws IllegalArgumentException if {@code batchSize} is negative or zero
     */
    public KL_LUCB(int batchSize) {
        if (!ParameterValidation.isUnsigned(batchSize)) {
            throw new IllegalArgumentException("Batch size must not be negative");
        }
        // A batch size of 0 would add no samples per round, so the confidence gap checked in
        // identify() could never shrink through sampling and the loop would in general not terminate.
        if (batchSize == 0) {
            throw new IllegalArgumentException("Batch size must be positive");
        }
        this.batchSize = batchSize;
    }

    /**
     * Performs one bound update and selects the two candidates to sample next.
     *
     * <p>Candidates are ranked by their current precision. The last {@code topN} entries of
     * {@code MathUtils.argSort} form the top set. This assumes {@code argSort} orders ascending,
     * which is consistent with {@link #identify} returning those entries as the best.
     *
     * <p>Upper bounds are written only for the non-top candidates and lower bounds only for the top
     * candidates. All other entries of {@code ub}/{@code lb} keep whatever earlier calls stored.
     *
     * @param t          1-based round counter, forwarded to {@code KLBernoulliUtils.computeBeta}
     * @param candidates candidates whose precision and sample size are read
     * @param delta      confidence parameter, forwarded to {@code KLBernoulliUtils.computeBeta}
     * @param topN       size of the top set
     * @param ub         upper-bound array indexed like {@code candidates}; updated in place
     * @param lb         lower-bound array indexed like {@code candidates}; updated in place
     * @return a two-element array {@code {ut, lt}}. {@code ut} is the index of the non-top candidate
     *         with the largest upper bound, or 0 if every candidate is in the top set. {@code lt} is
     *         the index of the top candidate with the smallest lower bound.
     */
    static int[] updateBounds(int t, List<AnchorCandidate> candidates, double delta, int topN,
                              double[] ub, double[] lb) {
        final double[] means = getMultipleMeans(candidates);
        final int[] sortedIndices = MathUtils.argSort(means);
        final double beta = KLBernoulliUtils.computeBeta(candidates.size(), t, delta);
        final int split = means.length - topN;
        final int[] topIndices = Arrays.copyOfRange(sortedIndices, split, means.length);
        final int[] restIndices = Arrays.copyOfRange(sortedIndices, 0, split);

        for (int idx : restIndices) {
            ub[idx] = KLBernoulliUtils.dupBernoulli(means[idx], beta / candidates.get(idx).getSampledSize());
        }
        for (int idx : topIndices) {
            lb[idx] = KLBernoulliUtils.dlowBernoulli(means[idx], beta / candidates.get(idx).getSampledSize());
        }

        final int ut = restIndices.length == 0
                ? 0
                : restIndices[MathUtils.argMax(IntStream.of(restIndices).mapToDouble(i -> ub[i]).toArray())];
        final int lt = topIndices[MathUtils.argMin(IntStream.of(topIndices).mapToDouble(i -> lb[i]).toArray())];
        return new int[]{ut, lt};
    }

    /**
     * Collects the current precision of every candidate, in list order.
     *
     * @param candidates the candidates to read
     * @return an array where entry {@code i} is {@code candidates.get(i).getPrecision()}
     */
    private static double[] getMultipleMeans(List<AnchorCandidate> candidates) {
        final double[] means = new double[candidates.size()];
        for (int i = 0; i < means.length; i++) {
            means[i] = candidates.get(i).getPrecision();
        }
        return means;
    }

    /**
     * Samples candidates until the KL-LUCB stopping condition holds and returns the top set.
     *
     * @param candidates    candidates to choose from; their statistics are updated by sampling
     * @param samplingService service used to evaluate candidates
     * @param sessionLabel  value passed to {@code samplingService.createSession} in every round.
     *                      Presumably the label of the explained instance — confirm against
     *                      {@code BestAnchorIdentification}.
     * @param delta         confidence parameter, forwarded to the bound computation
     * @param epsilon       sampling stops once {@code ub[ut] - lb[lt] <= epsilon}
     * @param nrOfResults   number of candidates to return
     * @return the {@code nrOfResults} candidates with the highest precision, in ascending
     *         {@code MathUtils.argSort} order (best last)
     * @throws IllegalArgumentException if {@code nrOfResults} is not between 1 and
     *                                  {@code candidates.size()}
     */
    @Override
    public List<AnchorCandidate> identify(List<AnchorCandidate> candidates, SamplingService samplingService,
                                          int sessionLabel, double delta, double epsilon, int nrOfResults) {
        // Out-of-range values previously surfaced as ArrayIndexOutOfBoundsException deep inside
        // updateBounds; fail fast with a descriptive message instead.
        if (nrOfResults < 1 || nrOfResults > candidates.size()) {
            throw new IllegalArgumentException("Number of results must be between 1 and the number of candidates ("
                    + candidates.size() + "), but was " + nrOfResults);
        }

        final double[] ub = new double[candidates.size()];
        final double[] lb = new double[candidates.size()];
        int t = 1;
        int[] next = updateBounds(t, candidates, delta, nrOfResults, ub, lb);

        // next[0] = ut (challenger outside the top set), next[1] = lt (weakest top-set member).
        while (ub[next[0]] - lb[next[1]] > epsilon) {
            samplingService.createSession(sessionLabel)
                    .registerCandidateEvaluation(candidates.get(next[0]), this.batchSize)
                    .registerCandidateEvaluation(candidates.get(next[1]), this.batchSize)
                    .run();
            t++;
            next = updateBounds(t, candidates, delta, nrOfResults, ub, lb);
        }

        final double[] means = getMultipleMeans(candidates);
        final int[] sortedIndices = MathUtils.argSort(means);
        return Arrays.stream(sortedIndices, means.length - nrOfResults, means.length)
                .mapToObj(candidates::get)
                .collect(Collectors.toList());
    }
}
