From 99c3df8219e49c18e6cdf1732a817c34138962a4 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Mon, 30 Mar 2026 00:32:26 +0530 Subject: [PATCH 1/2] Write an optimized flash_attn_stream_k_fixup kernel Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst. Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst --- ggml/src/ggml-cuda/fattn-common.cuh | 98 +++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c59a4db399..1e2210b309 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -674,9 +674,85 @@ static __global__ void flash_attn_mask_to_KV_max( KV_max[sequence*ne31 + jt] = KV_max_sj; } -template // D == head size +template __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, + const int ne11, const int ne12, const int nbatch_fa, const int nblocks_stream_k) { + constexpr int ncols = ncols1*ncols2; + + const int tile_idx = blockIdx.x; // One block per output tile. + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks. + const int blocks_per_tile = nblocks_stream_k / gridDim.x; + + const int b_first = tile_idx * blocks_per_tile; + const int b_last = b_first + blocks_per_tile - 1; + + const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols); + + const int gqa_ratio = ne02 / ne12; + + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = tile_idx /(iter_j*iter_z_gqa*ne12); + const int z_KV = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence)/(iter_j*iter_z_gqa); + const int zt_gqa = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV)/iter_j; + const int jt = tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV - iter_j * zt_gqa; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup + float dst_val = *dst; + float max_val; + float rowsum; + { + const float2 tmp = dst_fixup[b_last*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Combine with all previous blocks in this tile. + for (int bidx = b_last - 1; bidx >= b_first; --bidx) { + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc]; + + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst +// (blocks_num.x not a multiple of ntiles_dst) +template +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup_fallback( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, const int ne12, const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; @@ -976,7 +1052,12 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); + const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); + // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks. + // do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy + const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst + ? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst + : nblocks_stream_k_raw; const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; @@ -1063,11 +1144,20 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) { + // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile. + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; + + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa, (int)blocks_num.x); + } else if (ntiles_dst % blocks_num.x != 0) { + // Fallback fixup for the cases where nblocks_stream_k < ntiles_dst. const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup_fallback <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); } From d1fd632ab800ba2fb86f6da231d6763057603814 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Mon, 30 Mar 2026 12:22:43 +0530 Subject: [PATCH 2/2] Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs --- ggml/src/ggml-cuda/fattn-common.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1e2210b309..cf610c2c25 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -748,7 +748,7 @@ static __global__ void flash_attn_stream_k_fixup( *dst = dst_val / rowsum; } -// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst +// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst // (blocks_num.x not a multiple of ntiles_dst) template __launch_bounds__(D, 1) @@ -1054,8 +1054,8 @@ void launch_fattn( const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks. - // do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy - const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst + // do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy + const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst ? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst : nblocks_stream_k_raw;