Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs

This commit is contained in:
Gaurav Garg 2026-03-30 12:22:43 +05:30
parent 99c3df8219
commit d1fd632ab8
1 changed files with 3 additions and 3 deletions

View File

@ -748,7 +748,7 @@ static __global__ void flash_attn_stream_k_fixup(
*dst = dst_val / rowsum;
}
// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst
// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst
// (blocks_num.x not a multiple of ntiles_dst)
template <int D, int ncols1, int ncols2>
__launch_bounds__(D, 1)
@ -1054,8 +1054,8 @@ void launch_fattn(
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
// do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy
const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst
// do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst
: nblocks_stream_k_raw;