package de.viadee.xai.anchor.algorithm.global;

import de.viadee.xai.anchor.algorithm.AnchorConstructionBuilder;
import de.viadee.xai.anchor.algorithm.AnchorResult;
import de.viadee.xai.anchor.algorithm.DataInstance;
import de.viadee.xai.anchor.algorithm.execution.ExecutorServiceFunction;
import de.viadee.xai.anchor.algorithm.execution.ExecutorServiceSupplier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/viadee/xai/anchor/algorithm/global/SubmodularPick.class */
/**
 * Global explainer that selects a subset of anchor results by a greedy procedure.
 *
 * <p>For the supplied results an importance matrix is built: one row per result, one column per
 * distinct feature occurring in any result. Each cell holds the (clamped) precision that was added
 * when the feature joined the result's candidate chain. Column importance is the mean of each
 * column. Results are then picked one after another, always taking the one with the highest score
 * computed by {@code SubmodularPickUtils}.
 *
 * @param <T> type of the explained data instances
 */
public class SubmodularPick<T extends DataInstance<?>> extends AbstractGlobalExplainer<T> {
    private static final Logger LOGGER = LoggerFactory.getLogger(SubmodularPick.class);

    /**
     * Value holder returned by {@link #createFeatureToColumnMap(AnchorResult[])}: the mapping used
     * to locate a feature's matrix column together with the total number of columns.
     *
     * <p>The map's value type is left open ({@code ?}); this class stores {@code Integer} column
     * indices in it and reads them back in
     * {@link #getCandidateFeatureIndex(Map, AnchorResult, int)}.
     */
    public class CreateFeatureToColumnMapResult {
        /** Mapping keyed by feature number; values are interpreted by {@code getCandidateFeatureIndex}. */
        final Map<Integer, ?> map;
        /** Number of columns the importance matrix must provide. */
        final int columnCount;

        /**
         * @param map mapping keyed by feature number
         * @param i   number of matrix columns
         */
        public CreateFeatureToColumnMapResult(Map<Integer, ?> map, int i) {
            this.map = map;
            this.columnCount = i;
        }
    }

    /**
     * Creates the explainer; all arguments are forwarded unchanged to the matching superclass
     * constructor.
     *
     * @param anchorConstructionBuilder builder passed to the superclass
     * @param i                         integer setting passed to the superclass
     * @param executorService           executor service passed to the superclass
     * @param executorServiceSupplier   executor service supplier passed to the superclass
     */
    public SubmodularPick(AnchorConstructionBuilder<T> anchorConstructionBuilder, int i, ExecutorService executorService, ExecutorServiceSupplier executorServiceSupplier) {
        super(anchorConstructionBuilder, i, executorService, executorServiceSupplier);
    }

    /**
     * Creates the explainer; all arguments are forwarded unchanged to the matching superclass
     * constructor.
     *
     * @param anchorConstructionBuilder builder passed to the superclass
     * @param i                         integer setting passed to the superclass
     * @param executorService           executor service passed to the superclass
     * @param executorServiceFunction   executor service function passed to the superclass
     */
    public SubmodularPick(AnchorConstructionBuilder<T> anchorConstructionBuilder, int i, ExecutorService executorService, ExecutorServiceFunction executorServiceFunction) {
        super(anchorConstructionBuilder, i, executorService, executorServiceFunction);
    }

    /**
     * Creates the explainer; both arguments are forwarded unchanged to the matching superclass
     * constructor.
     *
     * @param batchExplainer            batch explainer passed to the superclass
     * @param anchorConstructionBuilder builder passed to the superclass
     */
    public SubmodularPick(BatchExplainer<T> batchExplainer, AnchorConstructionBuilder<T> anchorConstructionBuilder) {
        super(batchExplainer, anchorConstructionBuilder);
    }

    /**
     * Computes the arithmetic mean of every column of the importance matrix.
     *
     * @param importanceMatrix matrix with one row per result; all rows are read up to the length
     *                         of the first row
     * @return one mean per column, or an empty array when the matrix has no rows
     */
    private double[] createColumnImportance(double[][] importanceMatrix) {
        // Without rows there is no first row to take the column count from.
        if (importanceMatrix.length == 0) {
            return new double[0];
        }
        final int columnCount = importanceMatrix[0].length;
        final double[] columnImportance = new double[columnCount];
        for (int column = 0; column < columnCount; column++) {
            double sum = 0.0d;
            for (double[] row : importanceMatrix) {
                sum += row[column];
            }
            columnImportance[column] = sum / importanceMatrix.length;
        }
        return columnImportance;
    }

    /**
     * Finds the importance of a feature within a result by walking up the chain of parent
     * candidates until reaching the one whose most recently added feature (the last entry of its
     * ordered feature list) equals {@code feature}.
     *
     * @param anchorResult result whose candidate chain is searched, starting with the result itself
     * @param feature      feature number to look for
     * @return the added precision of the matching candidate, clamped to the range [0, 1]
     * @throws RuntimeException if no candidate in the chain added the feature
     */
    private double computeFeatureImportance(AnchorResult<T> anchorResult, int feature) {
        final Integer wanted = Integer.valueOf(feature);
        AnchorResult<T> current = anchorResult;
        do {
            final List<Integer> orderedFeatures = current.getOrderedFeatures();
            // A candidate without features cannot have added the wanted one; skip it instead of
            // failing on index -1 so that an inconsistent chain ends in the error below.
            if (!orderedFeatures.isEmpty()
                    && orderedFeatures.get(orderedFeatures.size() - 1).equals(wanted)) {
                return Math.max(0.0d, Math.min(1.0d, current.getAddedPrecision()));
            }
            // NOTE(review): the decompiler reported a failed type inference here and named
            // de.viadee.xai.anchor.algorithm.AnchorCandidate as the type of this value. If
            // getParentCandidate() is declared to return that type, the local variable must be
            // declared with it for this assignment to compile - confirm against AnchorResult.
            current = current.getParentCandidate();
        } while (current != null);
        throw new RuntimeException("Should not happen - Inconsistent candidate inheritance!");
    }

    /**
     * Builds the importance matrix for the given results and greedily picks from them.
     *
     * @param anchorResultArr results to choose from
     * @param i               maximum number of results to pick
     * @return the picked results in pick order; empty when no results were supplied
     */
    @Override
    List<AnchorResult<T>> pickExplanations(AnchorResult<T>[] anchorResultArr, int i) {
        // Nothing to pick from; also avoids building a matrix without rows.
        if (anchorResultArr.length == 0) {
            return new ArrayList<>();
        }

        final SubmodularPick<T>.CreateFeatureToColumnMapResult mapping = createFeatureToColumnMap(anchorResultArr);
        final Map<Integer, ?> featureToColumn = mapping.map;
        final double[][] importanceMatrix = new double[anchorResultArr.length][mapping.columnCount];

        for (int row = 0; row < anchorResultArr.length; row++) {
            final AnchorResult<T> result = anchorResultArr[row];
            for (Integer featureBoxed : result.getOrderedFeatures()) {
                final int feature = featureBoxed.intValue();
                final int column = getCandidateFeatureIndex(featureToColumn, result, feature);
                importanceMatrix[row][column] = computeFeatureImportance(result, feature);
            }
        }
        return greedyPick(i, anchorResultArr, importanceMatrix, createColumnImportance(importanceMatrix));
    }

    /**
     * Assigns a matrix column to every distinct feature found in the results' ordered feature
     * lists. Columns are numbered from zero in order of first appearance.
     *
     * @param anchorResultArr results whose features are collected
     * @return the feature-to-column mapping and the number of distinct features
     */
    protected SubmodularPick<T>.CreateFeatureToColumnMapResult createFeatureToColumnMap(AnchorResult<T>[] anchorResultArr) {
        final Map<Integer, Integer> columnByFeature = new HashMap<>();
        for (AnchorResult<T> result : anchorResultArr) {
            for (Integer feature : result.getOrderedFeatures()) {
                if (!columnByFeature.containsKey(feature)) {
                    // The next free column number equals the number of features seen so far.
                    columnByFeature.put(feature, Integer.valueOf(columnByFeature.size()));
                }
            }
        }
        return new CreateFeatureToColumnMapResult(columnByFeature, columnByFeature.size());
    }

    /**
     * Looks up the matrix column of a feature. This implementation reads the column directly from
     * the map and does not use {@code anchorResult}.
     *
     * @param map          mapping produced by {@link #createFeatureToColumnMap(AnchorResult[])};
     *                     its values are cast to {@code Integer}
     * @param anchorResult result the feature belongs to (unused here)
     * @param i            feature number
     * @return column index stored for the feature
     */
    protected int getCandidateFeatureIndex(Map<Integer, ?> map, AnchorResult anchorResult, int i) {
        return ((Integer) map.get(Integer.valueOf(i))).intValue();
    }

    /**
     * Picks up to {@code desiredCount} results. In every round each remaining result is scored by
     * {@code SubmodularPickUtils.multiply(SubmodularPickUtils.colSum(matrix, picked, index), columnImportance)}
     * and the one with the strictly highest positive score is picked. Picking stops early when no
     * remaining result scores above zero.
     *
     * @param desiredCount     maximum number of results to pick
     * @param results          all results; row {@code k} of the matrix belongs to {@code results[k]}
     * @param importanceMatrix importance matrix built from the results
     * @param columnImportance importance of each matrix column
     * @return the picked results in pick order
     */
    private List<AnchorResult<T>> greedyPick(int desiredCount, AnchorResult<T>[] results, double[][] importanceMatrix, double[] columnImportance) {
        // Index of the first occurrence of each result, built once so that candidates need not be
        // searched for in the array again for every score computation.
        final HashMap<AnchorResult<T>, Integer> firstIndexByResult = new HashMap<>();
        for (int index = 0; index < results.length; index++) {
            if (!firstIndexByResult.containsKey(results[index])) {
                firstIndexByResult.put(results[index], Integer.valueOf(index));
            }
        }

        final LinkedHashSet<Integer> pickedIndices = new LinkedHashSet<>();
        final HashSet<AnchorResult<T>> remaining = new HashSet<>(Arrays.asList(results));
        final int rounds = Math.min(desiredCount, results.length);
        double previousBest = 0.0d;

        for (int round = 0; round < rounds; round++) {
            double bestScore = 0.0d;
            Integer bestIndex = null;
            for (AnchorResult<T> candidate : remaining) {
                final Integer candidateIndex = firstIndexByResult.get(candidate);
                final double score = SubmodularPickUtils.multiply(
                        SubmodularPickUtils.colSum(importanceMatrix, pickedIndices, candidateIndex.intValue()),
                        columnImportance);
                if (score > bestScore) {
                    bestIndex = candidateIndex;
                    bestScore = score;
                }
            }
            // No remaining candidate reaches a positive score, so further rounds cannot either.
            if (bestIndex == null) {
                break;
            }
            final AnchorResult<T> best = results[bestIndex.intValue()];
            LOGGER.info("Adding candidate {} adding coverage of {}, totalling to {}",
                    best.getCanonicalFeatures(), bestScore - previousBest, bestScore);
            previousBest = bestScore;
            pickedIndices.add(bestIndex);
            remaining.remove(best);
        }

        final List<AnchorResult<T>> picked = new ArrayList<>(pickedIndices.size());
        for (Integer index : pickedIndices) {
            picked.add(results[index.intValue()]);
        }
        return picked;
    }
}
