From c9b8479f7d1dee327ce03abb829adb40c9512865 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Fri, 12 Sep 2025 07:47:36 -0700 Subject: [PATCH] Added zero-initialization to att_out. Re-enabled flash attention when HWY_NATIVE_DOT_BF16 is not available. PiperOrigin-RevId: 806284756 --- gemma/flash_attention.cc | 5 +++++ gemma/gemma.cc | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 40096d1..ba1de3e 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -256,6 +256,11 @@ void TileFlashAttention( using DI = hn::ScalableTag; const DI di; using VI = hn::Vec<DI>; + const int kVTileSize = hn::MaxLanes(df); + for (int i = 0; i < kVTileSize; ++i) { + hwy::ZeroBytes(att_out.Row(0) + out_offsets[i], + v.Cols() * sizeof(att_out.Row(0)[0])); + } VI lasts = hn::LoadU(di, last_pos); VF old_m = hn::Set(df, -std::numeric_limits<float>::max() / 2.0f); VF old_d = hn::Zero(df); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 05583b3..778ecc6 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -75,7 +75,7 @@ void Attention(LayerAttentionType type, const size_t num_tokens, if (type == LayerAttentionType::kGemma) { // TODO: remove flag to enable FlashAttention. GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, - env, kAttentionUseOld); + env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0); } }