This commit is contained in:
rmatif 2026-02-06 22:14:31 +01:00 committed by GitHub
commit 65c8a07b4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 6 deletions

View File

@ -722,7 +722,11 @@ static __global__ void flash_attn_stream_k_fixup(
}
// Write back final result:
*dst = dst_val / rowsum;
if (!(rowsum > 0.0f)) {
*dst = 0.0f;
} else {
*dst = dst_val / rowsum;
}
}
template<int D> // D == head size

View File

@ -561,6 +561,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
KQ_max_new[col] = KQ_max[col];
}
float KQ_rowsum_add[cols_per_thread] = {0.0f};
constexpr int log2_nbatch_fa =
nbatch_fa == 256 ? 8 :
nbatch_fa == 128 ? 7 :
nbatch_fa == 64 ? 6 :
nbatch_fa == 32 ? 5 :
nbatch_fa == 16 ? 4 :
nbatch_fa == 8 ? 3 : 0;
static_assert(log2_nbatch_fa != 0, "unexpected nbatch_fa");
constexpr float kq_max_offset = FATTN_KQ_MAX_OFFSET + (np == 1 ? (log2_nbatch_fa - 3) * 0.69314718f : 0.0f);
if constexpr (cols_per_warp == 8) {
if (ncols2 > 1 || mask_h) {
@ -591,7 +601,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Turing + Volta:
const int KQ_idx = l % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + kq_max_offset);
}
}
}
@ -655,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// Turing + Volta:
const int KQ_idx = (l/2) % 2;
#endif // defined(AMD_WMMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + kq_max_offset);
}
}
}
@ -1430,8 +1440,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
if (!needs_fixup && !is_fixup) {
const float KQ_rowsum_j = meta_j[1];
dstk_val.x /= KQ_rowsum_j;
dstk_val.y /= KQ_rowsum_j;
if (!(KQ_rowsum_j > 0.0f)) {
dstk_val = make_float2(0.0f, 0.0f);
} else {
dstk_val.x /= KQ_rowsum_j;
dstk_val.y /= KQ_rowsum_j;
}
}
if (is_fixup) {

View File

@ -453,7 +453,11 @@ static __global__ void flash_attn_ext_f16(
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
if (!(KQ_rowsum_j > 0.0f)) {
dst_val = 0.0f;
} else {
dst_val /= KQ_rowsum_j;
}
}
dst[j_dst_unrolled*D + i] = dst_val;
}