fix FA rowsum/offset
This commit is contained in:
parent
0c21677e43
commit
23be8be3d7
|
|
@ -716,7 +716,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
}
|
||||
|
||||
// Write back final result:
|
||||
*dst = dst_val / rowsum;
|
||||
if (!(rowsum > 0.0f)) {
|
||||
*dst = 0.0f;
|
||||
} else {
|
||||
*dst = dst_val / rowsum;
|
||||
}
|
||||
}
|
||||
|
||||
template<int D> // D == head size
|
||||
|
|
|
|||
|
|
@ -561,6 +561,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
KQ_max_new[col] = KQ_max[col];
|
||||
}
|
||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||
constexpr int log2_nbatch_fa =
|
||||
nbatch_fa == 256 ? 8 :
|
||||
nbatch_fa == 128 ? 7 :
|
||||
nbatch_fa == 64 ? 6 :
|
||||
nbatch_fa == 32 ? 5 :
|
||||
nbatch_fa == 16 ? 4 :
|
||||
nbatch_fa == 8 ? 3 : 0;
|
||||
static_assert(log2_nbatch_fa != 0, "unexpected nbatch_fa");
|
||||
|
||||
constexpr float kq_max_offset = FATTN_KQ_MAX_OFFSET + (np == 1 ? (log2_nbatch_fa - 3) * 0.69314718f : 0.0f);
|
||||
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
if (ncols2 > 1 || mask_h) {
|
||||
|
|
@ -591,7 +601,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + kq_max_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -655,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + kq_max_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1429,8 +1439,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||
|
||||
if (!needs_fixup && !is_fixup) {
|
||||
const float KQ_rowsum_j = meta_j[1];
|
||||
dstk_val.x /= KQ_rowsum_j;
|
||||
dstk_val.y /= KQ_rowsum_j;
|
||||
if (!(KQ_rowsum_j > 0.0f)) {
|
||||
dstk_val = make_float2(0.0f, 0.0f);
|
||||
} else {
|
||||
dstk_val.x /= KQ_rowsum_j;
|
||||
dstk_val.y /= KQ_rowsum_j;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_fixup) {
|
||||
|
|
|
|||
|
|
@ -453,7 +453,11 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
if (!(KQ_rowsum_j > 0.0f)) {
|
||||
dst_val = 0.0f;
|
||||
} else {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
}
|
||||
dst[j_dst_unrolled*D + i] = dst_val;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue