diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index fe55de4b91..d2d775c9b2 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -984,9 +984,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa constexpr unsigned int NumThreads = ThreadsM * ThreadsN; const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); - const unsigned int K2MN = 8; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) { + if (BlocksM * BlocksN < nsm) { const unsigned int ksplit = 8; ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);