use a better criterian to use split-k

This commit is contained in:
bssrdf 2025-11-05 13:58:25 -05:00
parent 688de6d7d8
commit d9a48580fc
1 changed files with 2 additions and 2 deletions

View File

@ -984,9 +984,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
const unsigned int K2MN = 8;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) {
if (BlocksM * BlocksN < nsm) {
const unsigned int ksplit = 8;
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);