package de.viadee.xai.anchor.algorithm.execution;

import de.viadee.xai.anchor.algorithm.AnchorCandidate;
import de.viadee.xai.anchor.algorithm.ClassificationFunction;
import de.viadee.xai.anchor.algorithm.DataInstance;
import de.viadee.xai.anchor.algorithm.PerturbationFunction;
import de.viadee.xai.anchor.algorithm.execution.sampling.SamplingFunction;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/* loaded from: input_file:de/viadee/xai/anchor/algorithm/execution/BalancedParallelSamplingService.class */
public class BalancedParallelSamplingService<T extends DataInstance<?>> extends ParallelSamplingService<T> {
    private static final long serialVersionUID = 344301140970085409L;
    private final int threadCount;
    private final ExecutorServiceFunction executorServiceFunction;

    /* loaded from: input_file:de/viadee/xai/anchor/algorithm/execution/BalancedParallelSamplingService$BalancedParallelSession.class */
    private class BalancedParallelSession extends ParallelSamplingService<T>.ParallelSession {
        private static final long serialVersionUID = 8982485103898064125L;

        private BalancedParallelSession(int i) {
            super(BalancedParallelSamplingService.this, i);
        }

        @Override // de.viadee.xai.anchor.algorithm.execution.ParallelSamplingService.ParallelSession
        protected Collection<Callable<Object>> createCallables() {
            int threadCount = BalancedParallelSamplingService.this.getThreadCount();
            ArrayList arrayList = new ArrayList();
            int sum = this.samplingCountMap.values().stream().mapToInt(num -> {
                return num.intValue();
            }).sum();
            int i = sum / threadCount;
            int i2 = sum % threadCount;
            Iterator<Map.Entry<AnchorCandidate, Integer>> it = this.samplingCountMap.entrySet().iterator();
            if (!it.hasNext()) {
                return Collections.emptyList();
            }
            Map.Entry<AnchorCandidate, Integer> next = it.next();
            int intValue = next.getValue().intValue();
            for (int i3 = 0; i3 < threadCount; i3++) {
                ArrayList arrayList2 = new ArrayList();
                int i4 = i;
                if (i2 > 0) {
                    i4++;
                    i2--;
                }
                while (i4 > 0) {
                    int min = Math.min(i4, intValue);
                    AnchorCandidate key = next.getKey();
                    intValue -= min;
                    i4 -= min;
                    arrayList2.add(() -> {
                        doSample(key, min);
                    });
                    if (intValue < 1) {
                        if (!it.hasNext()) {
                            break;
                        }
                        next = it.next();
                        intValue = next.getValue().intValue();
                    }
                }
                arrayList.add(Executors.callable(() -> {
                    arrayList2.forEach((v0) -> {
                        v0.run();
                    });
                }));
            }
            return arrayList;
        }
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (getExecutorService() != null || this.executorServiceFunction == null) {
            return;
        }
        setExecutorService(this.executorServiceFunction.apply(Integer.valueOf(this.threadCount)));
    }

    public BalancedParallelSamplingService(ClassificationFunction<T> classificationFunction, PerturbationFunction<T> perturbationFunction, ExecutorService executorService, ExecutorServiceSupplier executorServiceSupplier, int i) {
        super(classificationFunction, perturbationFunction, executorService, executorServiceSupplier);
        this.threadCount = i;
        this.executorServiceFunction = null;
    }

    public BalancedParallelSamplingService(ClassificationFunction<T> classificationFunction, PerturbationFunction<T> perturbationFunction, ExecutorService executorService, ExecutorServiceFunction executorServiceFunction, int i) {
        super(classificationFunction, perturbationFunction, executorService, null);
        this.threadCount = i;
        this.executorServiceFunction = executorServiceFunction;
    }

    public BalancedParallelSamplingService(SamplingFunction samplingFunction, ExecutorService executorService, ExecutorServiceSupplier executorServiceSupplier, int i) {
        super(samplingFunction, executorService, executorServiceSupplier);
        this.threadCount = i;
        this.executorServiceFunction = null;
    }

    public BalancedParallelSamplingService(SamplingFunction samplingFunction, ExecutorService executorService, ExecutorServiceFunction executorServiceFunction, int i) {
        super(samplingFunction, executorService, null);
        this.threadCount = i;
        this.executorServiceFunction = executorServiceFunction;
    }

    @Override // de.viadee.xai.anchor.algorithm.execution.ParallelSamplingService, de.viadee.xai.anchor.algorithm.execution.SamplingService
    public SamplingService notifySamplingFunctionChange(SamplingFunction samplingFunction) {
        return new BalancedParallelSamplingService(samplingFunction, getExecutorService(), getExecutorServiceSupplier(), getThreadCount());
    }

    @Override // de.viadee.xai.anchor.algorithm.execution.ParallelSamplingService, de.viadee.xai.anchor.algorithm.execution.SamplingService
    public SamplingSession createSession(int i) {
        return new BalancedParallelSession(i);
    }

    protected int getThreadCount() {
        return this.threadCount;
    }
}
