fix FA rowsum/offset
This commit is contained in:
parent
0c21677e43
commit
23be8be3d7
|
|
@ -716,7 +716,11 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write back final result:
|
// Write back final result:
|
||||||
*dst = dst_val / rowsum;
|
if (!(rowsum > 0.0f)) {
|
||||||
|
*dst = 0.0f;
|
||||||
|
} else {
|
||||||
|
*dst = dst_val / rowsum;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int D> // D == head size
|
template<int D> // D == head size
|
||||||
|
|
|
||||||
|
|
@ -561,6 +561,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
KQ_max_new[col] = KQ_max[col];
|
KQ_max_new[col] = KQ_max[col];
|
||||||
}
|
}
|
||||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||||
|
constexpr int log2_nbatch_fa =
|
||||||
|
nbatch_fa == 256 ? 8 :
|
||||||
|
nbatch_fa == 128 ? 7 :
|
||||||
|
nbatch_fa == 64 ? 6 :
|
||||||
|
nbatch_fa == 32 ? 5 :
|
||||||
|
nbatch_fa == 16 ? 4 :
|
||||||
|
nbatch_fa == 8 ? 3 : 0;
|
||||||
|
static_assert(log2_nbatch_fa != 0, "unexpected nbatch_fa");
|
||||||
|
|
||||||
|
constexpr float kq_max_offset = FATTN_KQ_MAX_OFFSET + (np == 1 ? (log2_nbatch_fa - 3) * 0.69314718f : 0.0f);
|
||||||
|
|
||||||
if constexpr (cols_per_warp == 8) {
|
if constexpr (cols_per_warp == 8) {
|
||||||
if (ncols2 > 1 || mask_h) {
|
if (ncols2 > 1 || mask_h) {
|
||||||
|
|
@ -591,7 +601,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
// Turing + Volta:
|
// Turing + Volta:
|
||||||
const int KQ_idx = l % 2;
|
const int KQ_idx = l % 2;
|
||||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + kq_max_offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -655,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
// Turing + Volta:
|
// Turing + Volta:
|
||||||
const int KQ_idx = (l/2) % 2;
|
const int KQ_idx = (l/2) % 2;
|
||||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + kq_max_offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1429,8 +1439,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
|
|
||||||
if (!needs_fixup && !is_fixup) {
|
if (!needs_fixup && !is_fixup) {
|
||||||
const float KQ_rowsum_j = meta_j[1];
|
const float KQ_rowsum_j = meta_j[1];
|
||||||
dstk_val.x /= KQ_rowsum_j;
|
if (!(KQ_rowsum_j > 0.0f)) {
|
||||||
dstk_val.y /= KQ_rowsum_j;
|
dstk_val = make_float2(0.0f, 0.0f);
|
||||||
|
} else {
|
||||||
|
dstk_val.x /= KQ_rowsum_j;
|
||||||
|
dstk_val.y /= KQ_rowsum_j;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_fixup) {
|
if (is_fixup) {
|
||||||
|
|
|
||||||
|
|
@ -453,7 +453,11 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
float dst_val = VKQ[j_VKQ*D_padded + i];
|
float dst_val = VKQ[j_VKQ*D_padded + i];
|
||||||
if (gridDim.y == 1) {
|
if (gridDim.y == 1) {
|
||||||
dst_val /= KQ_rowsum_j;
|
if (!(KQ_rowsum_j > 0.0f)) {
|
||||||
|
dst_val = 0.0f;
|
||||||
|
} else {
|
||||||
|
dst_val /= KQ_rowsum_j;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
dst[j_dst_unrolled*D + i] = dst_val;
|
dst[j_dst_unrolled*D + i] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue