CUDA: fix 0.0f/0.0f for FA fixup

This commit is contained in:
Johannes Gäßler 2025-12-30 00:17:01 +01:00
parent 51a48720b8
commit 4e02ad7f09
1 changed files with 2 additions and 2 deletions

View File

@ -713,7 +713,7 @@ static __global__ void flash_attn_stream_k_fixup(
}
// Write back final result:
*dst = dst_val / rowsum;
*dst = dst_val == 0.0f && rowsum == 0.0f ? 0.0f : dst_val / rowsum;
}
template<int D> // D == head size
@ -766,7 +766,7 @@ static __global__ void flash_attn_combine_results(
VKQ_denominator += KQ_max_scale * meta[l].y;
}
dst[tid] = VKQ_numerator / VKQ_denominator;
dst[tid] = VKQ_numerator == 0.0f && VKQ_denominator == 0.0f ? 0.0f : VKQ_numerator / VKQ_denominator;
}
template <int DV, int ncols1, int ncols2>