package org.nd4j.linalg.api.parallel;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinTask;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;

/**
 * Factory helpers that partition work over an {@link INDArray} (or several arrays with matching
 * layout) into one task per slice, or one task per tensor along a set of dimensions. Tasks are
 * produced either as plain {@link Runnable}s or as {@link ForkJoinTask}s.
 *
 * <p>Factories that return a {@link CountDownLatch} size it to the number of tasks they create.
 * Every such task counts that latch down exactly once when it finishes, whether it completes
 * normally or throws. Callers can therefore submit the tasks and then {@code await()} the latch.
 */
public class TaskCreator {

    /**
     * Executes an {@link Accumulation}, either as a whole or restricted to one slice / one
     * tensor along dimension. In the restricted case the scalar result is written into
     * {@code retArray} at index {@code slice}.
     *
     * <p>The arrays passed to {@link #perform} are ignored; all inputs come from the op.
     */
    public static class AccumulationINDArrayTask implements INDArrayTask {
        private final Accumulation op;
        private final OpExecutioner opExecutioner;
        /** Index of the slice/tensor to reduce, or -1 to execute the whole op. */
        private final int slice;
        /** Dimensions the tensor is taken along; {@code null} means {@code opForDimension(slice, 0)}. */
        private final int[] dimension;
        /** Receives the per-slice scalar result. */
        private final INDArray retArray;

        public AccumulationINDArrayTask(Accumulation accumulation, OpExecutioner opExecutioner, INDArray retArray) {
            this(accumulation, opExecutioner, -1, retArray, null);
        }

        public AccumulationINDArrayTask(Accumulation accumulation, OpExecutioner opExecutioner, INDArray retArray, int slice) {
            this(accumulation, opExecutioner, slice, retArray, null);
        }

        public AccumulationINDArrayTask(Accumulation accumulation, OpExecutioner opExecutioner, int slice, INDArray retArray, int[] dimension) {
            this.op = accumulation;
            this.opExecutioner = opExecutioner;
            this.slice = slice;
            this.dimension = dimension;
            this.retArray = retArray;
        }

        @Override
        public void perform(INDArray... arr) {
            if (slice >= 0 && dimension == null) {
                Accumulation sliceOp = (Accumulation) op.opForDimension(slice, 0);
                retArray.putScalar(slice, opExecutioner.execAndReturn(sliceOp).currentResult().doubleValue());
            } else if (dimension == null) {
                // No slice selected: run the accumulation over its full input.
                opExecutioner.exec(op);
            } else {
                Accumulation tensorOp = (Accumulation) op.opForDimension(slice, dimension);
                retArray.putScalar(slice, opExecutioner.execAndReturn(tensorOp).currentResult().doubleValue());
            }
        }
    }

    /**
     * Fork/join wrapper that runs an {@link INDArrayTask} on a group of arrays and then counts
     * down a latch. The latch is counted down even when the task throws.
     */
    public static class ForkJoinArrayINDArrayTask extends ForkJoinTask<INDArray[]> {
        protected INDArray[] arr;
        private final INDArrayTask task;
        private final CountDownLatch latch;

        public ForkJoinArrayINDArrayTask(INDArray[] arr, INDArrayTask task, CountDownLatch latch) {
            this.arr = arr;
            this.task = task;
            this.latch = latch;
        }

        @Override
        public INDArray[] getRawResult() {
            return arr;
        }

        @Override
        public void setRawResult(INDArray[] value) {
            this.arr = value;
        }

        @Override
        protected boolean exec() {
            try {
                task.perform(arr);
            } finally {
                // Always release waiters; otherwise a failing task would block await() forever.
                if (latch != null) {
                    latch.countDown();
                }
            }
            return true;
        }
    }

    /**
     * Fork/join wrapper that runs an {@link INDArrayTask} on a single array and then counts down
     * a latch. The latch is counted down even when the task throws.
     */
    public static class ForkJoinINDArrayTask extends ForkJoinTask<INDArray> {
        protected INDArray arr;
        private final INDArrayTask task;
        private final CountDownLatch latch;

        public ForkJoinINDArrayTask(INDArray arr, INDArrayTask task, CountDownLatch latch) {
            this.arr = arr;
            this.task = task;
            this.latch = latch;
        }

        @Override
        public INDArray getRawResult() {
            return arr;
        }

        @Override
        public void setRawResult(INDArray value) {
            this.arr = value;
        }

        @Override
        protected boolean exec() {
            try {
                task.perform(arr);
            } finally {
                // Always release waiters; otherwise a failing task would block await() forever.
                if (latch != null) {
                    latch.countDown();
                }
            }
            return true;
        }
    }

    /** A unit of work applied to one or more arrays. */
    public interface INDArrayTask {
        /**
         * Performs the work.
         *
         * @param arr the array(s) this task was partitioned onto; some implementations ignore
         *     them and take their inputs from an op instead
         */
        void perform(INDArray... arr);
    }

    /**
     * Executes an {@link Op}, either as a whole or restricted to one slice / one tensor along
     * dimension. For a restricted {@link TransformOp}, the sub-op's {@code z()} is copied into the
     * matching tensor of the parent op's {@code z()}.
     *
     * <p>The arrays passed to {@link #perform} are ignored. If a latch is supplied, it is counted
     * down once when {@link #perform} finishes, including when it throws.
     */
    public static class OpINDArrayTask implements INDArrayTask {
        private final Op op;
        private final OpExecutioner opExecutioner;
        /** Index of the slice/tensor to process, or -1 to execute the whole op. */
        private final int slice;
        /** Dimensions the tensor is taken along; {@code null} means {@code opForDimension(slice, 0)}. */
        private final int[] dimension;
        /** Optional completion latch; may be {@code null}. */
        private final CountDownLatch countDownLatch;

        public OpINDArrayTask(Op op, OpExecutioner opExecutioner, CountDownLatch countDownLatch) {
            this(op, opExecutioner, -1, null, countDownLatch);
        }

        public OpINDArrayTask(Op op, OpExecutioner opExecutioner, int slice, CountDownLatch countDownLatch) {
            this(op, opExecutioner, slice, null, countDownLatch);
        }

        public OpINDArrayTask(Op op, OpExecutioner opExecutioner, int slice, int[] dimension, CountDownLatch countDownLatch) {
            this.op = op;
            this.opExecutioner = opExecutioner;
            this.slice = slice;
            this.dimension = dimension;
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void perform(INDArray... arr) {
            try {
                if (slice >= 0 && dimension == null) {
                    Op sliceOp = op.opForDimension(slice, 0);
                    opExecutioner.exec(sliceOp);
                    if (op instanceof TransformOp) {
                        ((TransformOp) op).z().tensorAlongDimension(slice, 0).assign(((TransformOp) sliceOp).z());
                    }
                } else if (dimension != null) {
                    Op tensorOp = op.opForDimension(slice, dimension);
                    opExecutioner.exec(tensorOp);
                    if (op instanceof TransformOp) {
                        ((TransformOp) op).z().tensorAlongDimension(slice, dimension).assign(((TransformOp) tensorOp).z());
                    }
                } else {
                    opExecutioner.exec(op);
                }
            } finally {
                if (countDownLatch != null) {
                    countDownLatch.countDown();
                }
            }
        }
    }

    /**
     * {@link Runnable} that applies an {@link INDArrayTask} to a single array. It optionally
     * counts down a latch when done, even if the task throws.
     */
    public static class RunnableINDArrayTask implements Runnable {
        private final INDArray arr;
        private final INDArrayTask task;
        /** Optional completion latch; may be {@code null}. */
        private final CountDownLatch latch;

        public RunnableINDArrayTask(INDArray arr, INDArrayTask task) {
            this(arr, task, null);
        }

        public RunnableINDArrayTask(INDArray arr, INDArrayTask task, CountDownLatch latch) {
            this.arr = arr;
            this.task = task;
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                task.perform(arr);
            } finally {
                if (latch != null) {
                    latch.countDown();
                }
            }
        }
    }

    /**
     * {@link Runnable} that applies an {@link INDArrayTask} to a group of arrays and then counts
     * down a latch. The latch is counted down even when the task throws.
     */
    public static class RunnableMultipleINDArrayTask implements Runnable {
        private final INDArray[] arr;
        private final INDArrayTask task;
        private final CountDownLatch latch;

        public RunnableMultipleINDArrayTask(INDArray[] arr, INDArrayTask task, CountDownLatch latch) {
            this.arr = arr;
            this.task = task;
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                task.perform(arr);
            } finally {
                if (latch != null) {
                    latch.countDown();
                }
            }
        }
    }

    /**
     * Checks that every array has as many tensors along {@code dimension} as the first one.
     *
     * @return that common tensor count
     * @throws IllegalArgumentException if any array's tensor count differs
     */
    private static int requireEqualTensorCounts(INDArray[] arrays, int[] dimension) {
        int tensorCount = arrays[0].tensorssAlongDimension(dimension);
        for (int i = 1; i < arrays.length; i++) {
            if (arrays[i].tensorssAlongDimension(dimension) != tensorCount) {
                throw new IllegalArgumentException("Unable to parallellize operations with unequal number of tenosrs along dimension");
            }
        }
        return tensorCount;
    }

    /**
     * Checks that every array has as many slices as the first one.
     *
     * @return that common slice count
     * @throws IllegalArgumentException if any array's slice count differs
     */
    private static int requireEqualSlices(INDArray[] arrays) {
        int slices = arrays[0].slices();
        for (int i = 1; i < arrays.length; i++) {
            if (arrays[i].slices() != slices) {
                throw new IllegalArgumentException("Unable to parallelize; un equal slices for array " + i);
            }
        }
        return slices;
    }

    /** Returns {@code arrays[j].slice(index)} for every input array {@code j}. */
    private static INDArray[] slicesAt(INDArray[] arrays, int index) {
        INDArray[] result = new INDArray[arrays.length];
        for (int j = 0; j < arrays.length; j++) {
            result[j] = arrays[j].slice(index);
        }
        return result;
    }

    /** Returns {@code arrays[j].tensorAlongDimension(index, dimension)} for every input array {@code j}. */
    private static INDArray[] tensorsAt(INDArray[] arrays, int index, int[] dimension) {
        INDArray[] result = new INDArray[arrays.length];
        for (int j = 0; j < arrays.length; j++) {
            result[j] = arrays[j].tensorAlongDimension(index, dimension);
        }
        return result;
    }

    /**
     * Creates one fork/join task per slice of {@code arr}. Task {@code i} executes
     * {@code op.opForDimension(i, 0)}.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> parititonForkJoinBasedOnSlices(INDArray arr, Op op, OpExecutioner opExecutioner) {
        int slices = arr.slices();
        CountDownLatch latch = new CountDownLatch(slices);
        List<ForkJoinTask<INDArray>> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            // The fork/join wrapper owns the latch; the op task gets none, so each slice counts once.
            tasks.add(new ForkJoinINDArrayTask(arr.slice(i), new OpINDArrayTask(op, opExecutioner, i, null), latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one runnable per slice of {@code arr}. Runnable {@code i} executes
     * {@code op.opForDimension(i, 0)}.
     *
     * @return the runnables and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<Runnable>, CountDownLatch> parititonRunnablesBasedOnSlices(INDArray arr, Op op, OpExecutioner opExecutioner) {
        int slices = arr.slices();
        CountDownLatch latch = new CountDownLatch(slices);
        List<Runnable> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            // Here the op task owns the latch; the runnable wrapper has none.
            tasks.add(new RunnableINDArrayTask(arr.slice(i), new OpINDArrayTask(op, opExecutioner, i, latch)));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one fork/join task per tensor along {@code dimension}, shared across all input
     * arrays. Task {@code i} executes {@code op.opForDimension(i, dimension)}.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     * @throws IllegalArgumentException if the arrays have different tensor counts along {@code dimension}
     */
    public static Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> parititonForkJoinBasedOnTensorsAlongDimension(INDArray[] arrs, Op op, OpExecutioner opExecutioner, int... dimension) {
        int tensorCount = requireEqualTensorCounts(arrs, dimension);
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<ForkJoinTask<INDArray[]>> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new ForkJoinArrayINDArrayTask(tensorsAt(arrs, i, dimension), new OpINDArrayTask(op, opExecutioner, i, dimension, null), latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one runnable per tensor along {@code dimension}, shared across all input arrays.
     * Each runnable applies {@code task} to the corresponding tensors.
     *
     * @throws IllegalArgumentException if the arrays have different tensor counts along {@code dimension}
     */
    public static List<Runnable> parititonRunnablesBasedOnTensorsAlongDimension(INDArray[] arrs, INDArrayTask task, int... dimension) {
        int tensorCount = requireEqualTensorCounts(arrs, dimension);
        // The runnables need a non-null latch to count down, even though it is not returned here.
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<Runnable> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new RunnableMultipleINDArrayTask(tensorsAt(arrs, i, dimension), task, latch));
        }
        return tasks;
    }

    /**
     * Creates one fork/join task per tensor of {@code arr} along {@code dimension}. Task
     * {@code i} reduces {@code accumulation.opForDimension(i, dimension)} and stores the scalar
     * in {@code ret} at index {@code i}.
     */
    public static List<ForkJoinTask<INDArray>> parititonForkJoinBasedOnTensorsAlongDimension(INDArray arr, Accumulation accumulation, OpExecutioner opExecutioner, INDArray ret, int... dimension) {
        int tensorCount = arr.tensorssAlongDimension(dimension);
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<ForkJoinTask<INDArray>> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new ForkJoinINDArrayTask(arr.tensorAlongDimension(i, dimension), new AccumulationINDArrayTask(accumulation, opExecutioner, i, ret, dimension), latch));
        }
        return tasks;
    }

    /**
     * Creates one fork/join task per tensor of {@code arr} along {@code dimension}. Task
     * {@code i} executes {@code op.opForDimension(i, dimension)}.
     *
     * @return the latch shared by all tasks (first) and the tasks (second)
     */
    public static Pair<CountDownLatch, List<ForkJoinTask<INDArray>>> parititonForkJoinBasedOnTensorsAlongDimension(INDArray arr, Op op, OpExecutioner opExecutioner, int... dimension) {
        int tensorCount = arr.tensorssAlongDimension(dimension);
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<ForkJoinTask<INDArray>> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            // Only the fork/join wrapper counts the latch down; the op task gets none, so each
            // tensor is counted exactly once.
            tasks.add(new ForkJoinINDArrayTask(arr.tensorAlongDimension(i, dimension), new OpINDArrayTask(op, opExecutioner, i, dimension, null), latch));
        }
        // Hand back the latch the tasks actually count down.
        return new Pair<>(latch, tasks);
    }

    /**
     * Creates one runnable per tensor of {@code arr} along {@code dimension}. Each runnable
     * applies {@code task} to its tensor.
     */
    public static List<Runnable> parititonRunnablesBasedOnTensorsAlongDimension(INDArray arr, INDArrayTask task, int... dimension) {
        int tensorCount = arr.tensorssAlongDimension(dimension);
        List<Runnable> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new RunnableINDArrayTask(arr.tensorAlongDimension(i, dimension), task));
        }
        return tasks;
    }

    /**
     * Creates one fork/join task per slice index, shared across all input arrays. Task {@code i}
     * applies {@code task} to {@code arrs[j].slice(i)} for every array {@code j}.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     * @throws IllegalArgumentException if the arrays have different slice counts
     */
    public static Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> parititonForkJoinBasedOnSlices(INDArray[] arrs, INDArrayTask task) {
        int slices = requireEqualSlices(arrs);
        CountDownLatch latch = new CountDownLatch(slices);
        List<ForkJoinTask<INDArray[]>> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            tasks.add(new ForkJoinArrayINDArrayTask(slicesAt(arrs, i), task, latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one runnable per slice index, shared across all input arrays. Runnable {@code i}
     * applies {@code task} to {@code arrs[j].slice(i)} for every array {@code j}.
     *
     * @return the runnables and a latch that reaches zero once all of them have finished
     * @throws IllegalArgumentException if the arrays have different slice counts
     */
    public static Pair<List<Runnable>, CountDownLatch> parititonRunnablesBasedOnSlices(INDArray[] arrs, INDArrayTask task) {
        int slices = requireEqualSlices(arrs);
        CountDownLatch latch = new CountDownLatch(slices);
        List<Runnable> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            tasks.add(new RunnableMultipleINDArrayTask(slicesAt(arrs, i), task, latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one fork/join task per slice of {@code arr}. Each task applies {@code task} to its
     * slice.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> parititonForkJoinBasedOnSlices(INDArray arr, INDArrayTask task) {
        int slices = arr.slices();
        CountDownLatch latch = new CountDownLatch(slices);
        List<ForkJoinTask<INDArray>> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            tasks.add(new ForkJoinINDArrayTask(arr.slice(i), task, latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one runnable per slice of {@code arr}. Each runnable applies {@code task} to its
     * slice.
     *
     * @return the runnables and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<Runnable>, CountDownLatch> parititonRunnablesBasedOnSlices(INDArray arr, INDArrayTask task) {
        int slices = arr.slices();
        CountDownLatch latch = new CountDownLatch(slices);
        List<Runnable> tasks = new ArrayList<>(slices);
        for (int i = 0; i < slices; i++) {
            // Each runnable must count the returned latch down, or await() never returns.
            tasks.add(new RunnableINDArrayTask(arr.slice(i), task, latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one fork/join task per tensor of {@code arr} along {@code dimension}. Each task
     * applies {@code task} to its tensor.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<ForkJoinTask<INDArray>>, CountDownLatch> parititonForkJoinBasedOnTensorsAlongDimension(INDArray arr, INDArrayTask task, int... dimension) {
        int tensorCount = arr.tensorssAlongDimension(dimension);
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<ForkJoinTask<INDArray>> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new ForkJoinINDArrayTask(arr.tensorAlongDimension(i, dimension), task, latch));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one runnable per tensor of {@code arr} along {@code dimension}. Runnable {@code i}
     * executes {@code op.opForDimension(i, dimension)}.
     *
     * @return the runnables and a latch that reaches zero once all of them have finished
     */
    public static Pair<List<Runnable>, CountDownLatch> parititonRunnablesBasedOnTensorsAlongDimension(INDArray arr, Op op, OpExecutioner opExecutioner, int... dimension) {
        int tensorCount = arr.tensorssAlongDimension(dimension);
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<Runnable> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            // The op task ignores its array argument, so the whole array is passed through.
            tasks.add(new RunnableINDArrayTask(arr, new OpINDArrayTask(op, opExecutioner, i, dimension, latch)));
        }
        return new Pair<>(tasks, latch);
    }

    /**
     * Creates one fork/join task per tensor along {@code dimension}, shared across all input
     * arrays. Task {@code i} applies {@code task} to the {@code i}-th tensor of every array.
     *
     * <p>The tensor-count consistency check is skipped when the first array is a vector.
     * NOTE(review): the reason for that exemption is not visible here; confirm intent.
     *
     * @return the tasks and a latch that reaches zero once all of them have finished
     * @throws IllegalArgumentException if the first array is not a vector and the arrays have
     *     different tensor counts along {@code dimension}
     */
    public static Pair<List<ForkJoinTask<INDArray[]>>, CountDownLatch> parititonForkJoinBasedOnTensorsAlongDimension(INDArray[] arrs, INDArrayTask task, int... dimension) {
        int tensorCount = arrs[0].tensorssAlongDimension(dimension);
        if (!arrs[0].isVector()) {
            requireEqualTensorCounts(arrs, dimension);
        }
        CountDownLatch latch = new CountDownLatch(tensorCount);
        List<ForkJoinTask<INDArray[]>> tasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            tasks.add(new ForkJoinArrayINDArrayTask(tensorsAt(arrs, i, dimension), task, latch));
        }
        return new Pair<>(tasks, latch);
    }
}
