package jama.gpu;

import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDevice;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLException;
import com.nativelibs4java.opencl.CLMem;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import jama.FloatMatrix;
import jama.Matrix;
import java.io.IOException;
import org.bridj.Pointer;

/* loaded from: input_file:jama/gpu/GPU.class */
/**
 * Single-precision matrix multiplication executed on an OpenCL device through JavaCL.
 *
 * <p>Operands are padded so that both of their dimensions are multiples of the work-group edge
 * length, copied to the device in column-major order, multiplied by one of the two kernels exposed
 * by {@code MultiplicationKernel}, and the result is read back and cropped to the true product
 * size.
 *
 * <p>{@link #create()} lazily initialises a shared static context without any synchronisation.
 */
public class GPU {
    /** Edge length of the square local work-group; operands are padded to multiples of it. */
    private static final int WORK_GROUP_SIZE = 16;

    /** Context reused by every instance returned from {@link #create()}; created on first use. */
    private static CLContext defaultContext;

    /** OpenCL context used for every queue, buffer and kernel created by this instance. */
    private final CLContext context;

    /**
     * Returns a {@code GPU} bound to the shared default context, creating that context with
     * {@code JavaCL.createBestContext()} on the first call.
     *
     * @return a new {@code GPU} using the shared default context
     */
    public static GPU create() {
        if (defaultContext == null) {
            defaultContext = JavaCL.createBestContext();
        }
        return new GPU(defaultContext);
    }

    /**
     * Returns a {@code GPU} bound to the supplied context.
     *
     * @param cLContext the OpenCL context to use
     * @return a new {@code GPU} using {@code cLContext}
     */
    public static GPU create(CLContext cLContext) {
        return new GPU(cLContext);
    }

    private GPU(CLContext context) {
        this.context = context;
    }

    /**
     * Multiplies two matrices using the {@code floatMatrixMultLocals} kernel.
     *
     * <p>NOTE(review): this method selects the "Locals" kernel while {@link #multiplyLocal} selects
     * the plain one, which looks swapped relative to the method names. The mapping is kept as it
     * was; confirm the intent before changing it.
     *
     * @param floatMatrix left operand
     * @param floatMatrix2 right operand
     * @return the product, with {@code floatMatrix}'s row count and {@code floatMatrix2}'s column
     *     count
     * @throws IllegalArgumentException if the inner dimensions differ
     * @throws IOException declared because constructing the kernel wrapper may throw it
     */
    public FloatMatrix multiply(FloatMatrix floatMatrix, FloatMatrix floatMatrix2) throws IOException {
        return multiply(floatMatrix, floatMatrix2, true);
    }

    /**
     * Multiplies two matrices using the {@code floatMatrixMult} kernel.
     *
     * @param floatMatrix left operand
     * @param floatMatrix2 right operand
     * @return the product, with {@code floatMatrix}'s row count and {@code floatMatrix2}'s column
     *     count
     * @throws IllegalArgumentException if the inner dimensions differ
     * @throws IOException declared because constructing the kernel wrapper may throw it
     */
    public FloatMatrix multiplyLocal(FloatMatrix floatMatrix, FloatMatrix floatMatrix2) throws IOException {
        return multiply(floatMatrix, floatMatrix2, false);
    }

    /**
     * Shared implementation behind {@link #multiply(FloatMatrix, FloatMatrix)} and
     * {@link #multiplyLocal(FloatMatrix, FloatMatrix)}.
     *
     * <p>Every native resource is acquired inside the {@code try} and released in the
     * {@code finally}, so a failure at any step (including kernel construction) neither leaks the
     * resources acquired so far nor replaces the original exception with a
     * {@code NullPointerException} from the cleanup code.
     *
     * @param left left operand
     * @param right right operand
     * @param useLocalsKernel {@code true} to run {@code floatMatrixMultLocals}, {@code false} to
     *     run {@code floatMatrixMult}
     */
    private FloatMatrix multiply(FloatMatrix left, FloatMatrix right, boolean useLocalsKernel) throws IOException {
        if (left.getColumnDimension() != right.getRowDimension()) {
            throw new IllegalArgumentException("Matrix inner dimensions must agree.");
        }

        CLQueue queue = null;
        Pointer<Float> leftData = null;
        Pointer<Float> rightData = null;
        Pointer<Float> resultData = null;
        Pointer<Integer> innerDimension = null;
        Pointer<Float> readBack = null;
        CLBuffer<Float> leftBuffer = null;
        CLBuffer<Float> rightBuffer = null;
        CLBuffer<Integer> innerDimensionBuffer = null;
        CLBuffer<Float> resultBuffer = null;
        CLEvent completion = null;
        try {
            queue = this.context.createDefaultQueue(new CLDevice.QueueProperties[0]);

            FloatMatrix paddedLeft = zeroPadding(left, WORK_GROUP_SIZE);
            FloatMatrix paddedRight = zeroPadding(right, WORK_GROUP_SIZE);
            int paddedRows = paddedLeft.getRowDimension();
            int paddedColumns = paddedRight.getColumnDimension();

            leftData = matrixToPointer(paddedLeft);
            rightData = matrixToPointer(paddedRight);
            resultData = Pointer.allocateFloats((long) paddedRows * paddedColumns);
            // The kernels receive the padded inner dimension through a one-element int buffer.
            innerDimension = Pointer.allocateInt();
            innerDimension.set(Integer.valueOf(paddedLeft.getColumnDimension()));

            leftBuffer = this.context.createBuffer(CLMem.Usage.Input, leftData);
            rightBuffer = this.context.createBuffer(CLMem.Usage.Input, rightData);
            innerDimensionBuffer = this.context.createIntBuffer(CLMem.Usage.Input, innerDimension);
            resultBuffer = this.context.createBuffer(CLMem.Usage.Output, resultData);

            MultiplicationKernel kernel = new MultiplicationKernel(this.context);
            int[] localSizes = {WORK_GROUP_SIZE, WORK_GROUP_SIZE};
            int[] globalSizes = {paddedRows, paddedColumns};
            if (useLocalsKernel) {
                completion = kernel.floatMatrixMultLocals(queue, resultBuffer, leftBuffer, rightBuffer,
                        innerDimensionBuffer, globalSizes, localSizes, new CLEvent[0]);
            } else {
                completion = kernel.floatMatrixMult(queue, resultBuffer, leftBuffer, rightBuffer,
                        innerDimensionBuffer, globalSizes, localSizes, new CLEvent[0]);
            }

            // The read is given the kernel's event as a dependency.
            readBack = resultBuffer.read(queue, new CLEvent[]{completion});
            FloatMatrix paddedProduct = pointerToFloatMatrix(readBack, paddedRows, paddedColumns);
            return removeZeroPadding(paddedProduct, left, right);
        } finally {
            // Entries that were never allocated are still null here; the original cleanup already
            // passed a possibly-null read-back pointer to Pointer.release.
            Pointer.release(new Pointer[]{leftData, rightData, readBack, resultData, innerDimension});
            releaseBuffers(leftBuffer, rightBuffer, innerDimensionBuffer, resultBuffer);
            if (queue != null) {
                queue.release();
            }
            if (completion != null) {
                completion.release();
            }
        }
    }

    /** Releases each non-null buffer, in the order given. */
    private static void releaseBuffers(CLBuffer<?>... buffers) {
        for (CLBuffer<?> buffer : buffers) {
            if (buffer != null) {
                buffer.release();
            }
        }
    }

    /**
     * Copies a matrix into newly allocated native memory in column-major order: element
     * {@code (r, c)} is stored at index {@code r + rows * c}.
     *
     * @param matrix the matrix to copy
     * @return a pointer to {@code rows * columns} doubles; the caller is responsible for releasing
     *     it
     */
    protected static Pointer<Double> matrixToPointer(Matrix matrix) {
        int rows = matrix.getRowDimension();
        int columns = matrix.getColumnDimension();
        Pointer<Double> data = Pointer.allocateDoubles((long) columns * rows);
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                data.set(row + ((long) rows * column), Double.valueOf(matrix.get(row, column)));
            }
        }
        return data;
    }

    /**
     * Copies a matrix into newly allocated native memory in column-major order: element
     * {@code (r, c)} is stored at index {@code r + rows * c}.
     *
     * @param floatMatrix the matrix to copy
     * @return a pointer to {@code rows * columns} floats; the caller is responsible for releasing
     *     it
     */
    protected static Pointer<Float> matrixToPointer(FloatMatrix floatMatrix) {
        int rows = floatMatrix.getRowDimension();
        int columns = floatMatrix.getColumnDimension();
        Pointer<Float> data = Pointer.allocateFloats((long) columns * rows);
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                data.set(row + ((long) rows * column), Float.valueOf(floatMatrix.get(row, column)));
            }
        }
        return data;
    }

    /**
     * Builds an {@code i x i2} matrix from column-major native memory, reading element
     * {@code (r, c)} from index {@code r + rows * c}.
     *
     * @param pointer source memory holding at least {@code i * i2} doubles
     * @param i number of rows
     * @param i2 number of columns
     * @return a new matrix holding the copied values
     */
    protected static Matrix pointerToMatrix(Pointer<Double> pointer, int i, int i2) {
        Matrix result = new Matrix(i, i2);
        // Kept from the original: the stride is the row count reported by the new matrix.
        int rows = result.getRowDimension();
        for (int row = 0; row < i; row++) {
            for (int column = 0; column < i2; column++) {
                result.set(row, column, pointer.get(row + ((long) rows * column)).doubleValue());
            }
        }
        return result;
    }

    /**
     * Builds an {@code i x i2} matrix from column-major native memory, reading element
     * {@code (r, c)} from index {@code r + rows * c}.
     *
     * @param pointer source memory holding at least {@code i * i2} floats
     * @param i number of rows
     * @param i2 number of columns
     * @return a new matrix holding the copied values
     */
    protected static FloatMatrix pointerToFloatMatrix(Pointer<Float> pointer, int i, int i2) {
        FloatMatrix result = new FloatMatrix(i, i2);
        // Kept from the original: the stride is the row count reported by the new matrix.
        int rows = result.getRowDimension();
        for (int row = 0; row < i; row++) {
            for (int column = 0; column < i2; column++) {
                result.set(row, column, pointer.get(row + ((long) rows * column)).floatValue());
            }
        }
        return result;
    }

    /**
     * Rounds a dimension up to a multiple of the work-group edge length.
     *
     * @param i the dimension to round
     * @param i2 the work-group edge length
     * @return {@code i2} when {@code i <= i2}; otherwise {@code i} if it is already a multiple of
     *     {@code i2}, else the next multiple of {@code i2} above {@code i}
     * @throws ArithmeticException if {@code i2} is zero and {@code i} is positive
     */
    protected static int workgroupSize(int i, int i2) {
        if (i <= i2) {
            return i2;
        }
        int remainder = i % i2;
        if (remainder == 0) {
            return i;
        }
        return (i + i2) - remainder;
    }

    /**
     * Pads a matrix so that both dimensions satisfy {@link #workgroupSize(int, int)}.
     *
     * @param floatMatrix the matrix to pad
     * @param i the work-group edge length
     * @return {@code floatMatrix} itself when no padding is needed; otherwise a new, larger matrix
     *     with {@code floatMatrix} copied into its top-left corner (the remaining cells keep
     *     whatever {@code new FloatMatrix(rows, columns)} initialises them to, presumably zero)
     */
    protected static FloatMatrix zeroPadding(FloatMatrix floatMatrix, int i) {
        int rows = floatMatrix.getRowDimension();
        int columns = floatMatrix.getColumnDimension();
        int paddedRows = workgroupSize(rows, i);
        int paddedColumns = workgroupSize(columns, i);
        if (paddedRows == rows && paddedColumns == columns) {
            return floatMatrix;
        }
        FloatMatrix padded = new FloatMatrix(paddedRows, paddedColumns);
        padded.setFloatMatrix(0, rows - 1, 0, columns - 1, floatMatrix);
        return padded;
    }

    /**
     * Crops a padded product back to the size of the true product.
     *
     * @param floatMatrix the padded product
     * @param floatMatrix2 the original left operand; supplies the row count
     * @param floatMatrix3 the original right operand; supplies the column count
     * @return {@code floatMatrix} itself when it already has that size; otherwise its top-left
     *     sub-matrix of that size
     */
    protected static FloatMatrix removeZeroPadding(FloatMatrix floatMatrix, FloatMatrix floatMatrix2, FloatMatrix floatMatrix3) {
        int rows = floatMatrix2.getRowDimension();
        int columns = floatMatrix3.getColumnDimension();
        if (floatMatrix.getColumnDimension() == columns && floatMatrix.getRowDimension() == rows) {
            return floatMatrix;
        }
        return floatMatrix.getFloatMatrix(0, rows - 1, 0, columns - 1);
    }
}
