fix FA rowsum/offset

This commit is contained in:
rmatif 2026-01-26 01:02:16 +01:00
parent 0c21677e43
commit 23be8be3d7
3 changed files with 28 additions and 6 deletions

View File

@ -716,7 +716,11 @@ static __global__ void flash_attn_stream_k_fixup(
} }
// Write back final result: // Write back final result:
*dst = dst_val / rowsum; if (!(rowsum > 0.0f)) {
*dst = 0.0f;
} else {
*dst = dst_val / rowsum;
}
} }
template<int D> // D == head size template<int D> // D == head size

View File

@ -561,6 +561,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
KQ_max_new[col] = KQ_max[col]; KQ_max_new[col] = KQ_max[col];
} }
float KQ_rowsum_add[cols_per_thread] = {0.0f}; float KQ_rowsum_add[cols_per_thread] = {0.0f};
constexpr int log2_nbatch_fa =
nbatch_fa == 256 ? 8 :
nbatch_fa == 128 ? 7 :
nbatch_fa == 64 ? 6 :
nbatch_fa == 32 ? 5 :
nbatch_fa == 16 ? 4 :
nbatch_fa == 8 ? 3 : 0;
static_assert(log2_nbatch_fa != 0, "unexpected nbatch_fa");
constexpr float kq_max_offset = FATTN_KQ_MAX_OFFSET + (np == 1 ? (log2_nbatch_fa - 3) * 0.69314718f : 0.0f);
if constexpr (cols_per_warp == 8) { if constexpr (cols_per_warp == 8) {
if (ncols2 > 1 || mask_h) { if (ncols2 > 1 || mask_h) {
@ -591,7 +601,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Turing + Volta: // Turing + Volta:
const int KQ_idx = l % 2; const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE) #endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + kq_max_offset);
} }
} }
} }
@ -655,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Turing + Volta: // Turing + Volta:
const int KQ_idx = (l/2) % 2; const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE) #endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + kq_max_offset);
} }
} }
} }
@ -1429,8 +1439,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
if (!needs_fixup && !is_fixup) { if (!needs_fixup && !is_fixup) {
const float KQ_rowsum_j = meta_j[1]; const float KQ_rowsum_j = meta_j[1];
dstk_val.x /= KQ_rowsum_j; if (!(KQ_rowsum_j > 0.0f)) {
dstk_val.y /= KQ_rowsum_j; dstk_val = make_float2(0.0f, 0.0f);
} else {
dstk_val.x /= KQ_rowsum_j;
dstk_val.y /= KQ_rowsum_j;
}
} }
if (is_fixup) { if (is_fixup) {

View File

@ -453,7 +453,11 @@ static __global__ void flash_attn_ext_f16(
} }
float dst_val = VKQ[j_VKQ*D_padded + i]; float dst_val = VKQ[j_VKQ*D_padded + i];
if (gridDim.y == 1) { if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j; if (!(KQ_rowsum_j > 0.0f)) {
dst_val = 0.0f;
} else {
dst_val /= KQ_rowsum_j;
}
} }
dst[j_dst_unrolled*D + i] = dst_val; dst[j_dst_unrolled*D + i] = dst_val;
} }