package org.nd4j.linalg.api.parallel;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.parallel.TaskCreator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ParallelExecutioner} that fans work out either onto a {@link ForkJoinPool} or onto a
 * plain {@link ExecutorService}.
 *
 * <p>Every array/slice method uses the fork-join pool when one is set and otherwise falls back
 * to the executor service, which is created on first use if absent. {@link #exec(Runnable)} and
 * {@link #exec(ForkJoinTask)} likewise create their pool lazily. Lazy initialisation is not
 * synchronised, so concurrent first use may create more than one pool — NOTE(review): confirm
 * instances are confined to one thread or initialised eagerly.
 *
 * <p>Failure semantics:
 * <ul>
 *   <li>If the calling thread is interrupted while waiting, the interrupt flag is restored, a
 *       warning is logged and the method stops waiting (results may then be incomplete).</li>
 *   <li>A task failure observed through {@link Future#get()} is rethrown as an
 *       {@link IllegalStateException} carrying the original cause.</li>
 *   <li>Latch-based paths rely on the tasks built by {@link TaskCreator} counting the latch down;
 *       NOTE(review): if a task can fail before counting down, the wait would block forever —
 *       confirm TaskCreator counts down in a {@code finally}.</li>
 * </ul>
 */
public class DefaultParallelExecutioner implements ParallelExecutioner {
    private static final Logger log = LoggerFactory.getLogger(DefaultParallelExecutioner.class);

    private ExecutorService executorService;
    private ForkJoinPool forkJoinPool;

    /** Creates an executioner that dispatches array/slice work to the given fork-join pool. */
    public DefaultParallelExecutioner(ForkJoinPool forkJoinPool) {
        this.forkJoinPool = forkJoinPool;
    }

    /** Creates an executioner that dispatches array/slice work to the given executor service. */
    public DefaultParallelExecutioner(ExecutorService executorService) {
        this.executorService = executorService;
    }

    /** Creates an executioner backed by a new fork-join pool sized to the available processors. */
    public DefaultParallelExecutioner() {
        this(new ForkJoinPool(Runtime.getRuntime().availableProcessors(),
                ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false));
    }

    /**
     * Runs an accumulation along the given dimensions in parallel.
     *
     * @param arr         the array to partition into tensors along {@code dimension}
     * @param accumulation the accumulation op; its {@code x()} shape determines the result shape
     * @param executioner the op executioner the generated tasks use
     * @param dimension   the dimensions being reduced
     * @return a new array whose shape is {@code accumulation.x().shape()} with {@code dimension}
     *         removed
     * @throws IllegalStateException if a fork-join task fails
     */
    @Override
    public INDArray execBasedOnArraysAlongDimension(INDArray arr, Accumulation accumulation,
            OpExecutioner executioner, int... dimension) {
        INDArray result = Nd4j.create(ArrayUtil.removeIndex(accumulation.x().shape(), dimension));
        if (forkJoinPool != null) {
            List<ForkJoinTask<INDArray>> tasks =
                    TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(
                            arr, accumulation, executioner, result, dimension);
            executeAll(tasks);
            awaitAll(tasks);
        } else {
            // NOTE(review): unlike the fork-join branch, `result` is not handed to TaskCreator
            // here, so nothing visible in this method writes into it — confirm where these
            // runnables store their output.
            Pair<List<Runnable>, CountDownLatch> partition =
                    TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(
                            arr, accumulation, executioner, dimension);
            submitAll(partition.getFirst());
            awaitLatch(partition.getSecond());
        }
        return result;
    }

    /**
     * Runs {@code op} over the tensors of {@code arr} along the given dimensions in parallel and
     * blocks until the partition's completion latch is released.
     *
     * @param arr         the array to partition
     * @param op          the op to apply to each tensor
     * @param executioner the op executioner the generated tasks use
     * @param dimension   the dimensions defining each tensor
     */
    @Override
    public void execBasedOnArraysAlongDimension(INDArray arr, Op op, OpExecutioner executioner,
            int... dimension) {
        if (forkJoinPool != null) {
            Pair<CountDownLatch, List<ForkJoinTask<INDArray>>> forkJoinPartition =
                    TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(
                            arr, op, executioner, dimension);
            executeAll(forkJoinPartition.getSecond());
            awaitLatch(forkJoinPartition.getFirst());
            return;
        }
        Pair<List<Runnable>, CountDownLatch> runnablePartition =
                TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(
                        arr, op, executioner, dimension);
        submitAll(runnablePartition.getFirst());
        awaitLatch(runnablePartition.getSecond());
    }

    /**
     * Runs {@code op} over the slices of {@code arr} in parallel and blocks until the partition's
     * completion latch is released.
     *
     * @param arr         the array whose slices are processed
     * @param op          the op to apply to each slice
     * @param executioner the op executioner the generated tasks use
     */
    @Override
    public void execBasedOnSlices(INDArray arr, Op op, OpExecutioner executioner) {
        if (forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> forkJoinPartition =
                    TaskCreator.parititonForkJoinBasedOnSlices(arr, op, executioner);
            executeAll(forkJoinPartition.getFirst());
            awaitLatch(forkJoinPartition.getSecond());
            return;
        }
        Pair<List<Runnable>, CountDownLatch> runnablePartition =
                TaskCreator.parititonRunnablesBasedOnSlices(arr, op, executioner);
        submitAll(runnablePartition.getFirst());
        awaitLatch(runnablePartition.getSecond());
    }

    /**
     * Runs {@code task} over the tensors of {@code arr} along the given dimensions in parallel and
     * waits for completion.
     *
     * @param arr       the array to partition
     * @param task      the task to apply to each tensor
     * @param dimension the dimensions defining each tensor
     * @throws IllegalStateException if a task submitted to the executor service fails
     */
    @Override
    public void execBasedOnArraysAlongDimension(INDArray arr, TaskCreator.INDArrayTask task,
            int... dimension) {
        if (forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> forkJoinPartition =
                    TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arr, task, dimension);
            executeAll(forkJoinPartition.getFirst());
            awaitLatch(forkJoinPartition.getSecond());
            return;
        }
        List<Runnable> runnables =
                TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arr, task, dimension);
        awaitAll(submitAll(runnables));
    }

    /**
     * Runs {@code task} over matching tensors of several arrays along the given dimensions in
     * parallel and waits for completion.
     *
     * @param arrs      the arrays to partition together
     * @param task      the task to apply to each group of tensors
     * @param dimension the dimensions defining each tensor
     * @throws IllegalStateException if a task submitted to the executor service fails
     */
    @Override
    public void execBasedOnArraysAlongDimension(INDArray[] arrs, TaskCreator.INDArrayTask task,
            int... dimension) {
        if (forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> forkJoinPartition =
                    TaskCreator.parititonForkJoinBasedOnTensorsAlongDimension(arrs, task, dimension);
            executeAll(forkJoinPartition.getFirst());
            awaitLatch(forkJoinPartition.getSecond());
            return;
        }
        List<Runnable> runnables =
                TaskCreator.parititonRunnablesBasedOnTensorsAlongDimension(arrs, task, dimension);
        awaitAll(submitAll(runnables));
    }

    /**
     * Runs {@code task} over the slices of {@code arr} in parallel and blocks until the
     * partition's completion latch is released.
     *
     * @param arr  the array whose slices are processed
     * @param task the task to apply to each slice
     */
    @Override
    public void execBasedOnSlices(INDArray arr, TaskCreator.INDArrayTask task) {
        if (forkJoinPool != null) {
            Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> forkJoinPartition =
                    TaskCreator.parititonForkJoinBasedOnSlices(arr, task);
            executeAll(forkJoinPartition.getFirst());
            awaitLatch(forkJoinPartition.getSecond());
            return;
        }
        Pair<List<Runnable>, CountDownLatch> runnablePartition =
                TaskCreator.parititonRunnablesBasedOnSlices(arr, task);
        submitAll(runnablePartition.getFirst());
        awaitLatch(runnablePartition.getSecond());
    }

    /**
     * Submits a runnable to the executor service, creating a fixed pool sized to the available
     * processors on first use.
     *
     * @param runnable the work to run
     * @return the future returned by {@link ExecutorService#submit(Runnable)}
     */
    @Override
    public Future exec(Runnable runnable) {
        return executor().submit(runnable);
    }

    /**
     * Schedules a fork-join task asynchronously, creating a pool sized to the available
     * processors on first use.
     *
     * @param forkJoinTask the task to execute
     */
    @Override
    public <T> void exec(ForkJoinTask<T> forkJoinTask) {
        forkJoin().execute(forkJoinTask);
    }

    /** Returns the executor service, lazily creating a CPU-sized fixed thread pool if unset. */
    private ExecutorService executor() {
        if (executorService == null) {
            log.debug("Initializing parallel executioner executor");
            executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        }
        return executorService;
    }

    /** Returns the fork-join pool, lazily creating a CPU-sized pool if unset. */
    private ForkJoinPool forkJoin() {
        if (forkJoinPool == null) {
            log.debug("Initializing fork join parallel executor");
            forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
        }
        return forkJoinPool;
    }

    /** Schedules every task on the fork-join pool without waiting for them. */
    private void executeAll(List<? extends ForkJoinTask<?>> tasks) {
        ForkJoinPool pool = forkJoin();
        for (ForkJoinTask<?> task : tasks) {
            pool.execute(task);
        }
    }

    /** Submits every runnable to the executor service and returns their futures in order. */
    private List<Future<?>> submitAll(List<Runnable> runnables) {
        ExecutorService executor = executor();
        List<Future<?>> futures = new ArrayList<>(runnables.size());
        for (Runnable runnable : runnables) {
            futures.add(executor.submit(runnable));
        }
        return futures;
    }

    /**
     * Waits for each future in order. If interrupted, it restores the interrupt flag and stops
     * waiting. If a task failed, it rethrows the failure wrapped in an IllegalStateException.
     */
    private static void awaitAll(List<? extends Future<?>> futures) {
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers instead of silently clearing it.
                Thread.currentThread().interrupt();
                log.warn("Interrupted while waiting for parallel tasks to complete", e);
                return;
            } catch (ExecutionException e) {
                // Returning normally here would hand back a partially computed result.
                throw new IllegalStateException("Parallel task failed", e.getCause());
            }
        }
    }

    /** Blocks until the latch reaches zero; on interrupt restores the flag and stops waiting. */
    private static void awaitLatch(CountDownLatch latch) {
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("Interrupted while waiting for parallel tasks to complete", e);
        }
    }
}
