CUDA: fix 0.0f/0.0f for FA fixup
This commit is contained in:
parent
51a48720b8
commit
4e02ad7f09
|
|
@ -713,7 +713,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
}
|
||||
|
||||
// Write back final result:
|
||||
*dst = dst_val / rowsum;
|
||||
*dst = dst_val == 0.0f && rowsum == 0.0f ? 0.0f : dst_val / rowsum;
|
||||
}
|
||||
|
||||
template<int D> // D == head size
|
||||
|
|
@ -766,7 +766,7 @@ static __global__ void flash_attn_combine_results(
|
|||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[tid] = VKQ_numerator == 0.0f && VKQ_denominator == 0.0f ? 0.0f : VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
template <int DV, int ncols1, int ncols2>
|
||||
|
|
|
|||
Loading…
Reference in New Issue