This commit is contained in:
Gaurav Garg 2026-04-01 12:15:43 +01:00 committed by GitHub
commit 4126a9c6db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 94 additions and 4 deletions

View File

@ -674,9 +674,85 @@ static __global__ void flash_attn_mask_to_KV_max(
KV_max[sequence*ne31 + jt] = KV_max_sj;
}
template<int D, int ncols1, int ncols2> // D == head size
template<int D, int ncols1, int ncols2>
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa, const int nblocks_stream_k) {
constexpr int ncols = ncols1*ncols2;
const int tile_idx = blockIdx.x; // One block per output tile.
const int j = blockIdx.y;
const int c = blockIdx.z;
const int jc = j*ncols2 + c;
const int tid = threadIdx.x;
// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
const int blocks_per_tile = nblocks_stream_k / gridDim.x;
const int b_first = tile_idx * blocks_per_tile;
const int b_last = b_first + blocks_per_tile - 1;
const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
const int gqa_ratio = ne02 / ne12;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = tile_idx /(iter_j*iter_z_gqa*ne12);
const int z_KV = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence)/(iter_j*iter_z_gqa);
const int zt_gqa = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV)/iter_j;
const int jt = tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV - iter_j * zt_gqa;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup
float dst_val = *dst;
float max_val;
float rowsum;
{
const float2 tmp = dst_fixup[b_last*ncols + jc];
max_val = tmp.x;
rowsum = tmp.y;
}
// Combine with all previous blocks in this tile.
for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
const float max_val_new = fmaxf(max_val, tmp.x);
const float diff_val = max_val - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
dst_val = scale_val*dst_val + scale_add*dst_add;
rowsum = scale_val*rowsum + scale_add*tmp.y;
max_val = max_val_new;
}
// Write back final result:
*dst = dst_val / rowsum;
}
// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst
// (blocks_num.x not a multiple of ntiles_dst)
template <int D, int ncols1, int ncols2>
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup_fallback(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;
@ -976,7 +1052,12 @@ void launch_fattn(
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
// do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst
: nblocks_stream_k_raw;
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
@ -1063,11 +1144,20 @@ void launch_fattn(
CUDA_CHECK(cudaGetLastError());
if (stream_k) {
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
// Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa, (int)blocks_num.x);
} else if (ntiles_dst % blocks_num.x != 0) {
// Fallback fixup for the cases where nblocks_stream_k < ntiles_dst.
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
flash_attn_stream_k_fixup_fallback<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
}