diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 9b2331876b..fc3d25dfc8 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -991,9 +991,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - - if (BlocksM * BlocksN < nsm) { - const unsigned int ksplit = 8; + const unsigned int ksplit = 8; + if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) { ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); cudaFuncSetAttribute(conv2d_implicit_kernel, diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index e5da8ab056..41807a6b80 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -324,6 +324,7 @@ int main(void) std::make_tuple(256,128,832,1216,3,3), std::make_tuple(256,256,832,1216,3,3), // std::make_tuple(320,256,1024,1920) + // std::make_tuple(32,64,58,58,3,3) }; int k = 0;