From d1fd632ab800ba2fb86f6da231d6763057603814 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Mon, 30 Mar 2026 12:22:43 +0530 Subject: [PATCH] Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs --- ggml/src/ggml-cuda/fattn-common.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1e2210b309..cf610c2c25 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -748,7 +748,7 @@ static __global__ void flash_attn_stream_k_fixup( *dst = dst_val / rowsum; } -// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst +// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst // (blocks_num.x not a multiple of ntiles_dst) template __launch_bounds__(D, 1) @@ -1054,8 +1054,8 @@ void launch_fattn( const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks. - // do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy - const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst + // do this only if nblocks_stream_k_raw is more than 4x ntiles_dst to avoid excessive loss of occupancy + const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst ? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst : nblocks_stream_k_raw;