CUDA: fix negative KV_max values in FA (#15321)

This commit is contained in:
Johannes Gäßler 2025-08-14 23:21:24 +02:00 committed by GitHub
parent df36bce667
commit 4227c9be42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 1 deletions

View File

@ -539,11 +539,15 @@ static __global__ void flash_attn_mask_to_KV_max(
all_inf = warp_reduce_all(all_inf);
if (!all_inf) {
KV_max_sj += FATTN_KQ_STRIDE;
break;
}
}
// If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
// If the break was triggered it's the lower edge of the tile with the first non-masked values.
// In either case, walk back the decrementation by FATTN_KQ_STRIDE.
KV_max_sj += FATTN_KQ_STRIDE;
if (threadIdx.x != 0) {
return;
}