Added zero-initialization to att_out.

Re-enabled flash attention when HWY_NATIVE_DOT_BF16 is not available.

PiperOrigin-RevId: 806284756
This commit is contained in:
Ray Smith 2025-09-12 07:47:36 -07:00 committed by Copybara-Service
parent 2695aab5d2
commit c9b8479f7d
2 changed files with 6 additions and 1 deletions

View File

@ -256,6 +256,11 @@ void TileFlashAttention(
using DI = hn::ScalableTag<uint32_t>;
const DI di;
using VI = hn::Vec<DI>;
const int kVTileSize = hn::MaxLanes(df);
for (int i = 0; i < kVTileSize; ++i) {
hwy::ZeroBytes(att_out.Row(0) + out_offsets[i],
v.Cols() * sizeof(att_out.Row(0)[0]));
}
VI lasts = hn::LoadU(di, last_pos);
VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f);
VF old_d = hn::Zero(df);

View File

@ -75,7 +75,7 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
if (type == LayerAttentionType::kGemma) {
// TODO: remove flag to enable FlashAttention.
GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
env, kAttentionUseOld);
env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0);
}
}