package com.komputation.cuda.layers.forward.maxpooling;

import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.KernelLaunchConfiguration;
import com.komputation.cuda.kernels.launch.RowwiseKt;
import com.komputation.cuda.layers.BaseCudaForwardLayer;
import com.komputation.cuda.layers.CudaVariableLengthForwardLayer;
import com.komputation.layers.Resourceful;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CudaMaxPoolingLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��Z\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u0015\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\b\u0017\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0002\b\u0002\u0018��2\u00020\u00012\u00020\u00022\u00020\u0003BU\b��\u0012\b\u0010\u0004\u001a\u0004\u0018\u00010\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\u0007\u0012\u0006\u0010\t\u001a\u00020\n\u0012\f\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\f\u0012\f\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\r0\f\u0012\u0006\u0010\u000f\u001a\u00020\u0007\u0012\u0006\u0010\u0010\u001a\u00020\u0007¢\u0006\u0002\u0010\u0011J\u0010\u0010:\u001a\u00020;2\u0006\u0010&\u001a\u00020\u0007H\u0016J\u0018\u0010<\u001a\u00020\u001b2\u0006\u0010\u0013\u001a\u00020\u00072\u0006\u0010=\u001a\u00020\u001bH\u0016J(\u0010>\u001a\u00020\u001b2\u0006\u0010\u0013\u001a\u00020\u00072\u0006\u0010?\u001a\u00020\u001b2\u0006\u0010@\u001a\u00020\u001b2\u0006\u0010A\u001a\u00020BH\u0016J 
\u0010>\u001a\u00020\u001b2\u0006\u0010\u0013\u001a\u00020\u00072\u0006\u0010@\u001a\u00020\u001b2\u0006\u0010A\u001a\u00020BH\u0016J\b\u0010C\u001a\u00020;H\u0016R\u0010\u0010\u0012\u001a\u0004\u0018\u00010\rX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0013\u001a\u00020\u0014X\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u0015\u001a\u00020\u0007X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0016\u0010\u0017\"\u0004\b\u0018\u0010\u0019R\u0014\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\r0\fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001a\u001a\u00020\u001bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR\u0014\u0010\u001e\u001a\u00020\u001bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001f\u0010\u001dR\u000e\u0010 \u001a\u00020\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010!\u001a\u00020\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020#X\u0082\u000e¢\u0006\u0002\n��R\u0010\u0010$\u001a\u0004\u0018\u00010\rX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010%\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010&\u001a\u00020\u0007X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\b\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b'\u0010\u0017R\u000e\u0010(\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010)\u001a\u00020\u0007X\u0096D¢\u0006\b\n��\u001a\u0004\b*\u0010\u0017R\u0014\u0010+\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b,\u0010\u0017R\u0014\u0010-\u001a\u00020\u0007X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b.\u0010\u0017R\u000e\u0010/\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00100\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00102\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u00103\u001a\u00020\u001bX\u0082\u000e¢\u0006\u0002\n��R\u0016\u00104\u001a\n 
1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00105\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u00106\u001a\u00020\u001bX\u0082\u000e¢\u0006\u0002\n��R\u0016\u00107\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00108\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00109\u001a\n 1*\u0004\u0018\u00010\u001b0\u001bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��¨\u0006D"}, d2 = {"Lcom/komputation/cuda/layers/forward/maxpooling/CudaMaxPoolingLayer;", "Lcom/komputation/cuda/layers/BaseCudaForwardLayer;", "Lcom/komputation/layers/Resourceful;", "Lcom/komputation/cuda/layers/CudaVariableLengthForwardLayer;", "name", "", "numberRows", "", "maximumInputColumns", "symbolForUnusedColumns", "", "createForwardKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "createBackwardKernel", "warpSize", "maximumNumberThreadsPerBlock", "(Ljava/lang/String;IIFLkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function0;II)V", "backwardKernel", "batchSize", "", "count", "getCount", "()I", "setCount", "(I)V", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceForwardResult", "getDeviceForwardResult", "deviceMaxIndices", "deviceMaximumBatchLengths", "forwardConfiguration", "Lcom/komputation/cuda/kernels/launch/KernelLaunchConfiguration;", "forwardKernel", "forwardSharedMemoryBytes", "maximumBatchSize", "getMaximumInputColumns", "maximumNumberEntries", "maximumOutputColumns", "getMaximumOutputColumns", "numberInputRows", "getNumberInputRows", "numberOutputRows", "getNumberOutputRows", "numberWarps", "pointerToBackwardResult", "kotlin.jvm.PlatformType", "pointerToBatchSize", "pointerToForwardBatchLengths", "pointerToForwardResult", "pointerToMaxIndices", "pointerToMaximumBatchLengths", 
"pointerToMaximumNumberEntries", "pointerToNumberRows", "pointerToSymbolForUnusedColumns", "acquire", "", "backward", "chain", "forward", "deviceLengths", "deviceInput", "isTraining", "", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/maxpooling/CudaMaxPoolingLayer.class */
/**
 * CUDA-backed max-pooling layer.
 *
 * <p>The layer reports as many output rows as input rows and exactly one output column.
 * Device buffers and kernels are created in {@link #acquire(int)} and torn down in
 * {@link #release()}; {@code forward}/{@code backward} must only be called in between.
 *
 * <p>Two forward variants exist: one that uses a pre-uploaded array in which every instance
 * has {@code maximumInputColumns} columns, and one that takes a caller-supplied device array
 * of per-instance lengths. Whichever lengths pointer was used by the most recent forward call
 * is also handed to the backward kernel.
 *
 * <p>NOTE(review): this source was recovered from compiled Kotlin; the exact semantics of the
 * kernels and of the {@code Kernel.launch} arguments are defined outside this file.
 */
public final class CudaMaxPoolingLayer extends BaseCudaForwardLayer implements Resourceful, CudaVariableLengthForwardLayer {
    /** {@code numberRows * maximumInputColumns}: entries of one input instance. */
    private final int maximumNumberEntries;
    private final Pointer pointerToMaximumNumberEntries;
    private final Pointer pointerToNumberRows;

    /** Created in {@link #acquire(int)}; {@code null} before that. */
    private Kernel forwardKernel;

    @NotNull
    private final Pointer deviceForwardResult;
    private final Pointer pointerToForwardResult;
    private final int numberOutputRows;

    /** Always 1. Initialized here only; a final field must not be re-assigned in the constructor. */
    private final int maximumOutputColumns = 1;

    /** Created in {@link #acquire(int)}; {@code null} before that. */
    private Kernel backwardKernel;
    private final int numberInputRows;

    /** One-element host array; {@link #pointerToBatchSize} wraps it, so writes to [0] reach the kernels. */
    private final int[] batchSize;
    private final Pointer pointerToBatchSize;
    private final Pointer deviceMaxIndices;
    private final Pointer pointerToMaxIndices;
    private KernelLaunchConfiguration forwardConfiguration;
    private final int numberWarps;
    private final int forwardSharedMemoryBytes;

    /** -1 while no resources are held. */
    private int maximumBatchSize;

    @NotNull
    private final Pointer deviceBackwardResult;
    private final Pointer pointerToBackwardResult;
    private final Pointer deviceMaximumBatchLengths;
    private Pointer pointerToMaximumBatchLengths;

    /** Lengths pointer used by the latest forward pass; reused by {@link #backward(int, Pointer)}. */
    private Pointer pointerToForwardBatchLengths;
    private final Pointer pointerToSymbolForUnusedColumns;
    private int count;
    private final int maximumInputColumns;
    private final float symbolForUnusedColumns;
    private final Function0<Kernel> createForwardKernel;
    private final Function0<Kernel> createBackwardKernel;
    private final int warpSize;
    private final int maximumNumberThreadsPerBlock;

    /**
     * Builds the layer and all host-side pointers. No device memory is allocated here.
     *
     * @param name                         layer name, passed to the superclass; may be {@code null}
     * @param numberRows                   number of input rows (also the number of output rows)
     * @param maximumInputColumns          maximum number of input columns per instance
     * @param symbolForUnusedColumns       value passed to the backward kernel for unused columns
     * @param createForwardKernel          factory invoked in {@link #acquire(int)}
     * @param createBackwardKernel         factory invoked in {@link #acquire(int)}
     * @param warpSize                     device warp size, used for the launch configuration
     * @param maximumNumberThreadsPerBlock device limit, used for the launch configuration
     */
    public CudaMaxPoolingLayer(@Nullable String name, int numberRows, int maximumInputColumns, float symbolForUnusedColumns, @NotNull Function0<Kernel> createForwardKernel, @NotNull Function0<Kernel> createBackwardKernel, int warpSize, int maximumNumberThreadsPerBlock) {
        // Decompiler note: the 'super' call was moved ahead of the parameter checks.
        super(name);
        Intrinsics.checkParameterIsNotNull(createForwardKernel, "createForwardKernel");
        Intrinsics.checkParameterIsNotNull(createBackwardKernel, "createBackwardKernel");
        this.maximumInputColumns = maximumInputColumns;
        this.symbolForUnusedColumns = symbolForUnusedColumns;
        this.createForwardKernel = createForwardKernel;
        this.createBackwardKernel = createBackwardKernel;
        this.warpSize = warpSize;
        this.maximumNumberThreadsPerBlock = maximumNumberThreadsPerBlock;

        this.maximumNumberEntries = numberRows * maximumInputColumns;
        this.pointerToMaximumNumberEntries = Pointer.to(new int[]{this.maximumNumberEntries});
        this.pointerToNumberRows = Pointer.to(new int[]{numberRows});

        this.deviceForwardResult = new Pointer();
        this.pointerToForwardResult = Pointer.to(this.deviceForwardResult);
        this.numberOutputRows = numberRows;
        this.numberInputRows = numberRows;

        this.batchSize = new int[]{-1};
        this.pointerToBatchSize = Pointer.to(this.batchSize);

        this.deviceMaxIndices = new Pointer();
        this.pointerToMaxIndices = Pointer.to(this.deviceMaxIndices);

        this.forwardConfiguration = RowwiseKt.computeRowwiseLaunchConfiguration(numberRows, maximumInputColumns, warpSize, maximumNumberThreadsPerBlock);
        // Ceiling division: number of warps needed to cover one row of maximumInputColumns entries.
        this.numberWarps = ((maximumInputColumns + warpSize) - 1) / warpSize;
        this.forwardSharedMemoryBytes = (int) CudaIntArrayKt.computeDeviceIntArraySize(this.numberWarps);

        this.maximumBatchSize = -1;

        this.deviceBackwardResult = new Pointer();
        this.pointerToBackwardResult = Pointer.to(this.deviceBackwardResult);

        this.deviceMaximumBatchLengths = new Pointer();
        this.pointerToMaximumBatchLengths = new Pointer();
        this.pointerToForwardBatchLengths = new Pointer();
        this.pointerToSymbolForUnusedColumns = Pointer.to(new float[]{symbolForUnusedColumns});
    }

    @Override // com.komputation.cuda.CudaForwardState
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getMaximumInputColumns() {
        return this.maximumInputColumns;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    @NotNull
    public Pointer getDeviceBackwardResult() {
        return this.deviceBackwardResult;
    }

    public final int getCount() {
        return this.count;
    }

    public final void setCount(int i) {
        this.count = i;
    }

    /**
     * Creates both kernels and sets up the device buffers for up to {@code maximumBatchSize}
     * instances: max indices and forward result ({@code maximumBatchSize * numberInputRows}
     * elements each), backward result ({@code maximumBatchSize * maximumNumberEntries} floats),
     * and a lengths array in which every instance has {@code maximumInputColumns} columns.
     *
     * @param maximumBatchSize largest batch size that will be passed to forward/backward
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int maximumBatchSize) {
        this.maximumBatchSize = maximumBatchSize;
        this.forwardKernel = this.createForwardKernel.invoke();
        this.backwardKernel = this.createBackwardKernel.invoke();

        CudaIntArrayKt.allocateDeviceIntMemory(this.deviceMaxIndices, maximumBatchSize * this.numberInputRows);
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceForwardResult, maximumBatchSize * this.numberInputRows);
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceBackwardResult, maximumBatchSize * this.maximumNumberEntries);

        // Default lengths used by the fixed-length forward pass.
        int[] maximumBatchLengths = new int[maximumBatchSize];
        java.util.Arrays.fill(maximumBatchLengths, this.maximumInputColumns);
        CudaIntArrayKt.setIntArray(maximumBatchLengths, maximumBatchSize, this.deviceMaximumBatchLengths);

        Pointer pointer = Pointer.to(this.deviceMaximumBatchLengths);
        Intrinsics.checkExpressionValueIsNotNull(pointer, "Pointer.to(this.deviceMaximumBatchLengths)");
        this.pointerToMaximumBatchLengths = pointer;
    }

    /**
     * Forward pass that uses the default lengths array prepared in {@link #acquire(int)}.
     *
     * @param batchSize   number of instances in this batch
     * @param deviceInput device pointer to the input
     * @param isTraining  not read by this layer
     * @return the device pointer holding the forward result
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        return launchForward(batchSize, this.pointerToMaximumBatchLengths, deviceInput);
    }

    /**
     * Forward pass with caller-supplied per-instance lengths.
     *
     * @param batchSize     number of instances in this batch
     * @param deviceLengths device pointer to the per-instance lengths
     * @param deviceInput   device pointer to the input
     * @param isTraining    not read by this layer
     * @return the device pointer holding the forward result
     */
    @Override // com.komputation.cuda.layers.CudaVariableLengthForwardLayer
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceLengths, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceLengths, "deviceLengths");
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        // Typed as Pointer (not NativePointerObject) so it can be stored in pointerToForwardBatchLengths.
        Pointer pointerToLengths = Pointer.to(deviceLengths);
        Intrinsics.checkExpressionValueIsNotNull(pointerToLengths, "pointerToLengths");
        return launchForward(batchSize, pointerToLengths, deviceInput);
    }

    /**
     * Shared body of both forward variants: publishes the batch size, launches the forward
     * kernel and remembers which lengths pointer was used so that backward can reuse it.
     */
    private Pointer launchForward(int batchSize, Pointer pointerToLengths, Pointer deviceInput) {
        this.batchSize[0] = batchSize;
        Kernel kernel = this.forwardKernel;
        if (kernel == null) {
            // acquire() has not been called.
            Intrinsics.throwNpe();
        }
        Pointer parameters = Pointer.to(
            this.pointerToBatchSize,
            pointerToLengths,
            this.pointerToMaximumNumberEntries,
            this.pointerToMaxIndices,
            Pointer.to(deviceInput),
            this.pointerToForwardResult);
        Intrinsics.checkExpressionValueIsNotNull(parameters, "Pointer.to(\n            …rwardResult\n            )");
        kernel.launch(parameters, this.maximumBatchSize, this.forwardConfiguration.getNumberBlocks(), this.forwardConfiguration.getNumberThreadsPerBlock(), this.forwardSharedMemoryBytes);
        this.pointerToForwardBatchLengths = pointerToLengths;
        return this.deviceForwardResult;
    }

    /**
     * Backward pass. Uses the batch size and lengths pointer recorded by the most recent
     * forward call; the {@code batchSize} argument itself is not read.
     *
     * @param batchSize number of instances in this batch (unused here)
     * @param chain     device pointer to the incoming gradient
     * @return the device pointer holding the backward result
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer backward(int batchSize, @NotNull Pointer chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");
        Kernel kernel = this.backwardKernel;
        if (kernel == null) {
            // acquire() has not been called.
            Intrinsics.throwNpe();
        }
        Pointer parameters = Pointer.to(
            this.pointerToBatchSize,
            this.pointerToForwardBatchLengths,
            this.pointerToSymbolForUnusedColumns,
            this.pointerToMaximumNumberEntries,
            this.pointerToNumberRows,
            this.pointerToMaxIndices,
            Pointer.to(chain),
            this.pointerToBackwardResult);
        Intrinsics.checkExpressionValueIsNotNull(parameters, "Pointer.to(\n            …kwardResult\n            )");
        kernel.launch(parameters, this.maximumBatchSize, this.numberInputRows, this.maximumInputColumns, 0);
        return this.deviceBackwardResult;
    }

    /**
     * Destroys both kernels and frees the four device buffers set up in {@link #acquire(int)}.
     * Throws a NullPointerException if the kernels were never created.
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        this.maximumBatchSize = -1;

        Kernel forward = this.forwardKernel;
        if (forward == null) {
            Intrinsics.throwNpe();
        }
        forward.destroy();

        Kernel backward = this.backwardKernel;
        if (backward == null) {
            Intrinsics.throwNpe();
        }
        backward.destroy();

        JCuda.cudaFree(this.deviceForwardResult);
        JCuda.cudaFree(this.deviceBackwardResult);
        JCuda.cudaFree(this.deviceMaxIndices);
        JCuda.cudaFree(this.deviceMaximumBatchLengths);
    }
}
