package com.komputation.optimization;

import com.komputation.cpu.optimization.CpuStochasticGradientDescent;
import com.komputation.cuda.CudaContext;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.OptimizationKernels;
import com.komputation.cuda.optimization.CudaStochasticGradientDescent;
import kotlin.Metadata;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.functions.Function3;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: StochasticGradientDescent.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��0\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2\u00020\u0001B\u000f\b��\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u001a\u0010\u0005\u001a\u0014\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\b0\u0006H\u0016J(\u0010\t\u001a\u001a\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u0007\u0012\u0004\u0012\u00020\u000b0\n2\u0006\u0010\f\u001a\u00020\rH\u0016R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u000e"}, d2 = {"Lcom/komputation/optimization/StochasticGradientDescent;", "Lcom/komputation/optimization/OptimizationInstruction;", "learningRate", "", "(F)V", "buildForCpu", "Lkotlin/Function2;", "", "Lcom/komputation/cpu/optimization/CpuStochasticGradientDescent;", "buildForCuda", "Lkotlin/Function3;", "Lcom/komputation/cuda/optimization/CudaStochasticGradientDescent;", "context", "Lcom/komputation/cuda/CudaContext;", "komputation"})
/* loaded from: input_file:com/komputation/optimization/StochasticGradientDescent.class */
public final class StochasticGradientDescent implements OptimizationInstruction {

    /** Step size handed unchanged to every CPU and CUDA optimizer built by this instruction. */
    private final float learningRate;

    /**
     * Creates an optimization instruction for plain stochastic gradient descent.
     *
     * @param learningRate the step size passed to each optimizer this instruction builds
     */
    public StochasticGradientDescent(float learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Returns a factory for CPU optimizers.
     *
     * <p>The returned function ignores both of its integer arguments; every invocation creates a new
     * {@link CpuStochasticGradientDescent} configured with this instruction's learning rate.
     *
     * @return a two-argument factory producing a fresh CPU optimizer per call
     */
    @Override // com.komputation.optimization.CpuOptimizationInstruction
    @NotNull
    public Function2<Integer, Integer, CpuStochasticGradientDescent> buildForCpu() {
        // Function2 declares a single abstract method (invoke), so a Java lambda implements it
        // directly; the decompiled anonymous class with `super(2)` in an initializer was not legal Java.
        return (unusedFirst, unusedSecond) -> new CpuStochasticGradientDescent(learningRate);
    }

    /**
     * Returns a factory for CUDA optimizers bound to {@code context}.
     *
     * <p>The returned function ignores its first argument. It sizes each
     * {@link CudaStochasticGradientDescent} by the product of its second and third arguments
     * (presumably the row and column counts of the parameter being optimized — TODO confirm against
     * CudaOptimizationInstruction). Each optimizer receives:
     * <ul>
     *   <li>this instruction's learning rate,</li>
     *   <li>a kernel factory that, when invoked, calls {@code context.createKernel} with
     *       {@code OptimizationKernels.INSTANCE.stochasticGradientDescent()},</li>
     *   <li>the multiprocessor, resident-warp, warp-size and threads-per-block limits reported by
     *       {@code context}.</li>
     * </ul>
     * The limits are read from the context each time the factory is invoked.
     *
     * @param context CUDA context supplying kernels and device limits; a null value is rejected via
     *     Kotlin's {@code Intrinsics.checkParameterIsNotNull}
     * @return a three-argument factory producing a fresh CUDA optimizer per call
     */
    @Override // com.komputation.optimization.CudaOptimizationInstruction
    @NotNull
    public Function3<Integer, Integer, Integer, CudaStochasticGradientDescent> buildForCuda(@NotNull final CudaContext context) {
        // Kept from the Kotlin-compiled original so the null-argument exception and message are unchanged.
        Intrinsics.checkParameterIsNotNull(context, "context");

        return (unusedFirst, numberRows, numberColumns) -> {
            // Passed as a deferred factory: the kernel is only created if the optimizer invokes it.
            final Function0<Kernel> kernelFactory =
                () -> context.createKernel(OptimizationKernels.INSTANCE.stochasticGradientDescent());

            return new CudaStochasticGradientDescent(
                numberRows * numberColumns,
                learningRate,
                kernelFactory,
                context.getNumberMultiprocessors(),
                context.getMaximumNumberOfResidentWarpsPerMultiprocessor(),
                context.getWarpSize(),
                context.getMaximumNumberOfThreadsPerBlock());
        };
    }
}
