diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 1f5f1b9206..bef1773dca 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -716,7 +716,11 @@ static __global__ void flash_attn_stream_k_fixup(
     }

     // Write back final result:
-    *dst = dst_val / rowsum;
+    if (!(rowsum > 0.0f)) {
+        *dst = 0.0f;
+    } else {
+        *dst = dst_val / rowsum;
+    }
 }

 template<int D> // D == head size
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 9004d46904..88ca99f328 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -561,6 +561,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         KQ_max_new[col] = KQ_max[col];
     }
     float KQ_rowsum_add[cols_per_thread] = {0.0f};
+    constexpr int log2_nbatch_fa =
+        nbatch_fa == 256 ? 8 :
+        nbatch_fa == 128 ? 7 :
+        nbatch_fa ==  64 ? 6 :
+        nbatch_fa ==  32 ? 5 :
+        nbatch_fa ==  16 ? 4 :
+        nbatch_fa ==   8 ? 3 : 0;
+    static_assert(log2_nbatch_fa != 0, "unexpected nbatch_fa");
+
+    constexpr float kq_max_offset = FATTN_KQ_MAX_OFFSET + (np == 1 ? (log2_nbatch_fa - 3) * 0.69314718f : 0.0f);

     if constexpr (cols_per_warp == 8) {
         if (ncols2 > 1 || mask_h) {
@@ -591,7 +601,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     // Turing + Volta:
                     const int KQ_idx = l % 2;
 #endif // defined(AMD_WMMA_AVAILABLE)
-                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + kq_max_offset);
                 }
             }
         }
@@ -655,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     // Turing + Volta:
                     const int KQ_idx = (l/2) % 2;
 #endif // defined(AMD_WMMA_AVAILABLE)
-                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + kq_max_offset);
                 }
             }
         }
@@ -1429,8 +1439,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

                 if (!needs_fixup && !is_fixup) {
                     const float KQ_rowsum_j = meta_j[1];
-                    dstk_val.x /= KQ_rowsum_j;
-                    dstk_val.y /= KQ_rowsum_j;
+                    if (!(KQ_rowsum_j > 0.0f)) {
+                        dstk_val = make_float2(0.0f, 0.0f);
+                    } else {
+                        dstk_val.x /= KQ_rowsum_j;
+                        dstk_val.y /= KQ_rowsum_j;
+                    }
                 }

                 if (is_fixup) {
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 8694fd06c7..b1035bfaba 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -453,7 +453,11 @@ static __global__ void flash_attn_ext_f16(
         }
         float dst_val = VKQ[j_VKQ*D_padded + i];
         if (gridDim.y == 1) {
-            dst_val /= KQ_rowsum_j;
+            if (!(KQ_rowsum_j > 0.0f)) {
+                dst_val = 0.0f;
+            } else {
+                dst_val /= KQ_rowsum_j;
+            }
         }
         dst[j_dst_unrolled*D + i] = dst_val;
     }
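
Note on the guard shape: the patch writes the check as !(rowsum > 0.0f) rather than rowsum <= 0.0f. The negated comparison is also true when rowsum is NaN, so a fully masked-out row (rowsum of 0, or NaN produced by -inf arithmetic) writes 0 instead of propagating inf/NaN into dst. A minimal host-side sketch, not part of the patch, with safe_div as a stand-in for the three patched write-back sites:

// Standalone sketch: why !(rowsum > 0.0f) is used instead of rowsum <= 0.0f.
#include <cmath>
#include <cstdio>
#include <limits>

static float safe_div(float dst_val, float rowsum) {
    if (!(rowsum > 0.0f)) { // true for 0.0f, negative values, and NaN
        return 0.0f;
    }
    return dst_val / rowsum;
}

int main() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    printf("%f\n", safe_div(1.0f, 2.0f)); // 0.500000
    printf("%f\n", safe_div(1.0f, 0.0f)); // 0.000000 (inf without the guard)
    printf("%f\n", safe_div(1.0f, nan));  // 0.000000 (nan without the guard)
    return 0;
}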
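On the kq_max_offset arithmetic: 0.69314718f is ln(2) and log2_nbatch_fa == log2(nbatch_fa), so for np == 1 the extra term is (log2(nbatch_fa) - 3) * ln(2) = ln(nbatch_fa/8). Raising the running maximum m by ln(nbatch_fa/8) rescales every exp(s - m) by 8/nbatch_fa, so a batch of nbatch_fa values, each contributing at most 1.0, adds at most 8.0 to the row sum, a bound that presumably is meant to be independent of the batch size. A sketch (not part of the patch) checking that reading of the constant:

// Sketch: (log2_nbatch_fa - 3) * ln(2) == ln(nbatch_fa / 8).
#include <cmath>
#include <cstdio>

int main() {
    for (int nbatch_fa : {8, 16, 32, 64, 128, 256}) {
        const int log2_nbatch_fa = (int) std::log2((double) nbatch_fa);
        const float extra = (log2_nbatch_fa - 3) * 0.69314718f;
        // Shifting the max by ln(nbatch_fa/8) scales each exp(s - m)
        // by 8/nbatch_fa, capping the batch's row-sum contribution at 8.
        printf("nbatch_fa=%3d  extra=%f  ln(nbatch_fa/8)=%f\n",
               nbatch_fa, extra, logf(nbatch_fa / 8.0f));
    }
    return 0;
}

The static_assert in the patch guards the same assumption at compile time: the ternary ladder only covers the power-of-two batch sizes 8 through 256, and any other nbatch_fa falls through to 0 and fails the build.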