CUDA: fix 0.0f/0.0f (NaN) in the FlashAttention stream-k fixup and combine-results kernels
This commit is contained in:
parent
51a48720b8
commit
4e02ad7f09
|
|
@ -713,7 +713,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write back final result:
|
// Write back final result:
|
||||||
*dst = dst_val / rowsum;
|
*dst = dst_val == 0.0f && rowsum == 0.0f ? 0.0f : dst_val / rowsum;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int D> // D == head size
|
template<int D> // D == head size
|
||||||
|
|
@ -766,7 +766,7 @@ static __global__ void flash_attn_combine_results(
|
||||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
dst[tid] = VKQ_numerator == 0.0f && VKQ_denominator == 0.0f ? 0.0f : VKQ_numerator / VKQ_denominator;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int DV, int ncols1, int ncols2>
|
template <int DV, int ncols1, int ncols2>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue