From 4e02ad7f0986fa9b0cebb4a7a11934b5f639d410 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 30 Dec 2025 00:17:01 +0100
Subject: [PATCH] CUDA: fix 0.0f/0.0f for FA fixup

---
 ggml/src/ggml-cuda/fattn-common.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 8dc82a9d3b..967b79ed3c 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -713,7 +713,7 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-    *dst = dst_val / rowsum;
+    *dst = dst_val == 0.0f && rowsum == 0.0f ? 0.0f : dst_val / rowsum;
 }
 
 template // D == head size
@@ -766,7 +766,7 @@ static __global__ void flash_attn_combine_results(
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator == 0.0f && VKQ_denominator == 0.0f ? 0.0f : VKQ_numerator / VKQ_denominator;
 }
 
 template