package de.viadee.xai.anchor.algorithm.global;

import de.viadee.xai.anchor.algorithm.AnchorConstructionBuilder;
import de.viadee.xai.anchor.algorithm.AnchorResult;
import de.viadee.xai.anchor.algorithm.DataInstance;
import de.viadee.xai.anchor.algorithm.execution.ExecutorServiceFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.ListIterator;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/viadee/xai/anchor/algorithm/global/CoveragePick.class */
/**
 * Global explainer that greedily picks anchor results by descending coverage.
 *
 * <p>After each pick, every remaining result that shares at least one canonical feature index
 * with the picked result, with an equal instance value at that index, is discarded. When
 * {@code includeTargetValue} is set, only remaining results explaining the same label as the
 * picked result are discarded.
 *
 * @param <T> the data instance type of the explained results
 */
public class CoveragePick<T extends DataInstance<?>> extends AbstractGlobalExplainer<T> {
    private static final Logger LOGGER = LoggerFactory.getLogger(CoveragePick.class);

    /**
     * If {@code true}, a picked result only eliminates candidates with the same explained label,
     * and the coverage summary is logged once per label.
     */
    private final boolean includeTargetValue;

    /**
     * Creates a coverage pick that ignores the explained label when discarding overlapping results.
     *
     * @param anchorConstructionBuilder builder forwarded to the {@link AbstractGlobalExplainer} constructor
     * @param i                         integer forwarded unchanged to the {@link AbstractGlobalExplainer}
     *                                  constructor (NOTE(review): presumably a thread limit — confirm
     *                                  against the superclass)
     * @param executorService           executor service forwarded to the superclass
     * @param executorServiceFunction   executor service function forwarded to the superclass
     */
    public CoveragePick(AnchorConstructionBuilder<T> anchorConstructionBuilder, int i,
                        ExecutorService executorService, ExecutorServiceFunction executorServiceFunction) {
        super(anchorConstructionBuilder, i, executorService, executorServiceFunction);
        this.includeTargetValue = false;
    }

    /**
     * Creates a coverage pick backed by the given batch explainer.
     *
     * @param includeTargetValue        if {@code true}, overlap pruning and coverage logging are done
     *                                  per explained label
     * @param batchExplainer            batch explainer forwarded to the superclass
     * @param anchorConstructionBuilder builder forwarded to the superclass
     */
    public CoveragePick(boolean includeTargetValue, BatchExplainer<T> batchExplainer,
                        AnchorConstructionBuilder<T> anchorConstructionBuilder) {
        super(batchExplainer, anchorConstructionBuilder);
        this.includeTargetValue = includeTargetValue;
    }

    /**
     * Greedily selects up to {@code maxResults} results by coverage.
     *
     * <p>Each round takes the remaining result with the strictly highest coverage (the first one
     * wins ties). It then removes every remaining result that is redundant with the pick, as
     * decided by {@link #isRedundant}. Selection stops early once no candidate is left. Results
     * whose coverage is not greater than {@code -1.0}, including NaN, are never picked.
     *
     * @param anchorResultArr the candidate results; not modified
     * @param maxResults      maximum number of results to pick; non-positive values pick none
     * @return the picked results in pick order
     * @throws NullPointerException if {@code anchorResultArr} is {@code null}
     */
    @Override
    List<AnchorResult<T>> pickExplanations(AnchorResult<T>[] anchorResultArr, int maxResults) {
        Objects.requireNonNull(anchorResultArr, "anchorResultArr");
        final List<AnchorResult<T>> remaining = new ArrayList<>(Arrays.asList(anchorResultArr));
        final List<AnchorResult<T>> picked = new ArrayList<>();
        for (int round = 0; round < maxResults; round++) {
            final int bestIndex = indexOfHighestCoverage(remaining);
            if (bestIndex < 0) {
                break;
            }
            final AnchorResult<T> best = remaining.remove(bestIndex);
            picked.add(best);
            remaining.removeIf(candidate -> isRedundant(best, candidate));
        }
        logCoverage(picked);
        return picked;
    }

    /** Returns the index of the first candidate with the strictly highest coverage above -1.0, or -1. */
    private int indexOfHighestCoverage(List<AnchorResult<T>> candidates) {
        double highest = -1.0d;
        int bestIndex = -1;
        for (int idx = 0; idx < candidates.size(); idx++) {
            final double coverage = candidates.get(idx).getCoverage();
            if (coverage > highest) {
                highest = coverage;
                bestIndex = idx;
            }
        }
        return bestIndex;
    }

    /**
     * Tells whether {@code candidate} should be dropped after {@code picked} was selected.
     *
     * <p>A candidate is dropped when it shares a canonical feature index with the pick and both
     * instances hold equal values at that index. When {@code includeTargetValue} is set, the
     * candidate must also explain the same label as the pick.
     */
    private boolean isRedundant(AnchorResult<T> picked, AnchorResult<T> candidate) {
        if (includeTargetValue
                && picked.getExplainedInstanceLabel() != candidate.getExplainedInstanceLabel()) {
            return false;
        }
        return candidate.getCanonicalFeatures().stream().anyMatch(feature ->
                picked.getCanonicalFeatures().contains(feature)
                        // Objects.equals: a null instance value must not abort the whole pick
                        && Objects.equals(
                                candidate.getInstance().getValue(feature.intValue()),
                                picked.getInstance().getValue(feature.intValue())));
    }

    /**
     * Logs the summed coverage of the picked results, per explained label if {@code includeTargetValue}.
     *
     * <p>NOTE(review): the message prints the plain sum of {@code getCoverage()} values followed by
     * "%". Confirm whether coverage is a fraction in [0, 1] rather than a percentage.
     */
    private void logCoverage(List<AnchorResult<T>> picked) {
        if (!LOGGER.isInfoEnabled()) {
            return;
        }
        if (includeTargetValue) {
            picked.stream()
                    .map(AnchorResult::getExplainedInstanceLabel)
                    .distinct()
                    .forEach(label -> {
                        // Count only this label's results (previously the total pick count was logged).
                        final long labelCount = picked.stream()
                                .filter(result -> result.getExplainedInstanceLabel() == label.intValue())
                                .count();
                        final double labelCoverage = picked.stream()
                                .filter(result -> result.getExplainedInstanceLabel() == label.intValue())
                                .mapToDouble(AnchorResult::getCoverage)
                                .sum();
                        LOGGER.info("The returned {} results for label {} exclusively cover a total of {}% of the model's input",
                                labelCount, label, labelCoverage);
                    });
        } else {
            final double totalCoverage = picked.stream()
                    .mapToDouble(AnchorResult::getCoverage)
                    .sum();
            LOGGER.info("The returned {} results exclusively cover a total of {}% of the model's input",
                    picked.size(), totalCoverage);
        }
    }
}
