diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8dc82a9d3b..967b79ed3c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -713,7 +713,7 @@ static __global__ void flash_attn_stream_k_fixup( } // Write back final result: - *dst = dst_val / rowsum; + *dst = dst_val == 0.0f && rowsum == 0.0f ? 0.0f : dst_val / rowsum; } template // D == head size @@ -766,7 +766,7 @@ static __global__ void flash_attn_combine_results( VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[tid] = VKQ_numerator / VKQ_denominator; + dst[tid] = VKQ_numerator == 0.0f && VKQ_denominator == 0.0f ? 0.0f : VKQ_numerator / VKQ_denominator; } template