package net.imglib2.parallel;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.test.RandomImgs;
import net.imglib2.type.numeric.integer.IntType;
import net.imglib2.view.Views;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.OptionsBuilder;

/**
 * JMH benchmark measuring how long it takes to sum all pixels of a random
 * 100 x 500 x 500 {@link IntType} image under different {@link Parallelization}
 * execution strategies.
 * <p>
 * The summation recursively splits the image into hyperslices and dispatches them
 * through {@link Parallelization#getTaskExecutor()}, so the executor selected by each
 * benchmark method determines how the work is scheduled.
 */
@Warmup(iterations = 5, time = 200, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 5, time = 200, timeUnit = TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Fork(1)
@BenchmarkMode({Mode.AverageTime})
public class ParallelizationBenchmark {

    /** Random test image; the fixed seed (42) keeps the data reproducible across runs. */
    private final RandomAccessibleInterval<IntType> image = RandomImgs.seed(42).nextImage(new IntType(), new long[]{100, 500, 500});

    /**
     * Sums all pixel values of {@code interval}.
     * <p>
     * Inputs with at most one dimension are summed directly. Higher-dimensional inputs
     * are cut into hyperslices along their last dimension, and each slice is summed
     * recursively through the current task executor.
     *
     * @param interval the image to sum
     * @return the sum of all pixel values
     */
    private static long calculateSum(RandomAccessibleInterval<IntType> interval) {
        if (interval.numDimensions() <= 1) {
            return simpleCalculateSum(interval);
        }
        List<RandomAccessibleInterval<IntType>> slices = slices(interval);
        // The task executor may invoke the callback from several threads, so partial
        // sums are accumulated atomically.
        AtomicLong total = new AtomicLong();
        Parallelization.getTaskExecutor().forEach(slices, slice -> total.addAndGet(calculateSum(slice)));
        return total.get();
    }

    /**
     * Sums all pixel values of {@code interval} sequentially by iterating over it once.
     *
     * @param interval the image to sum
     * @return the sum of all pixel values
     */
    private static long simpleCalculateSum(RandomAccessibleInterval<IntType> interval) {
        long sum = 0;
        // A single iterator walks every pixel once (creating a new iterator per step
        // would never terminate).
        for (IntType pixel : Views.iterable(interval)) {
            sum += pixel.getInteger();
        }
        return sum;
    }

    /**
     * Splits {@code interval} into its hyperslices along the last dimension.
     *
     * @param interval the image to split; must have at least one dimension
     * @return one slice per position in {@code [min, max]} of the last dimension, in order
     */
    private static List<RandomAccessibleInterval<IntType>> slices(RandomAccessibleInterval<IntType> interval) {
        final int lastDimension = interval.numDimensions() - 1;
        return LongStream.rangeClosed(interval.min(lastDimension), interval.max(lastDimension))
                .<RandomAccessibleInterval<IntType>>mapToObj(position -> Views.hyperSlice(interval, lastDimension, position))
                .collect(Collectors.toList());
    }

    /** Sums the image on a fixed thread pool sized to the number of available processors. */
    @Benchmark
    public Long fixedThreadPool() {
        ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            return (Long) Parallelization.runWithExecutor(executor, () -> Long.valueOf(calculateSum(this.image)));
        } finally {
            // Release the pool's threads even if the computation throws.
            executor.shutdown();
        }
    }

    /** Sums the image on a {@link ForkJoinPool} limited to two threads. */
    @Benchmark
    public Long twoThreadsForkJoinPool() {
        ForkJoinPool executor = new ForkJoinPool(2);
        try {
            return (Long) Parallelization.runWithExecutor(executor, () -> Long.valueOf(calculateSum(this.image)));
        } finally {
            // Release the pool's threads even if the computation throws.
            executor.shutdown();
        }
    }

    /** Sums the image inside {@link Parallelization#runMultiThreaded}. */
    @Benchmark
    public Long multiThreaded() {
        return (Long) Parallelization.runMultiThreaded(() -> Long.valueOf(calculateSum(this.image)));
    }

    /** Sums the image inside {@link Parallelization#runSingleThreaded}. */
    @Benchmark
    public Long singleThreaded() {
        return (Long) Parallelization.runSingleThreaded(() -> Long.valueOf(calculateSum(this.image)));
    }

    /** Sums the image with whatever task executor is active when no executor is set explicitly. */
    @Benchmark
    public Long defaultBehavior() {
        return Long.valueOf(calculateSum(this.image));
    }

    /**
     * Baseline run.
     * <p>
     * NOTE(review): currently identical to {@link #defaultBehavior()}. If this is meant to
     * measure summation without any task-executor overhead, it should probably call
     * {@code simpleCalculateSum(image)} instead; confirm intent.
     */
    @Benchmark
    public Long singleThreadedBaseline() {
        return Long.valueOf(calculateSum(this.image));
    }

    /**
     * Runs all benchmarks declared in this class with JMH.
     *
     * @param strArr ignored
     * @throws RunnerException if JMH fails to run the benchmarks
     */
    public static void main(String... strArr) throws RunnerException {
        new Runner(new OptionsBuilder().include(ParallelizationBenchmark.class.getName()).build()).run();
    }
}
