From 61dedf73eda1117991d619d0e9c1f47f725af297 Mon Sep 17 00:00:00 2001
From: Krzysztof Rymski
Date: Mon, 8 Dec 2025 08:00:00 -0800
Subject: [PATCH] Internal changes

PiperOrigin-RevId: 841765739
---
 gemma/activations.h      |  1 +
 gemma/attention.cc       |  5 ++---
 gemma/attention.h        |  2 ++
 gemma/flash_attention.cc | 11 ++++++++---
 gemma/gemma.cc           |  6 ++++--
 5 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/gemma/activations.h b/gemma/activations.h
index 37f663f..021b2fd 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -25,6 +25,7 @@
 
 #include "gemma/configs.h"  // ModelConfig
 #include "gemma/gemma_args.h"  // AttentionImpl
+#include "gemma/kv_cache.h"
 #include "ops/ops.h"  // CreateInvTimescale
 #include "util/basics.h"  // BF16
 #include "util/mat.h"  // MatStorageT
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 5880b50..9ad8c39 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -321,9 +321,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
+void SumHeads(const LayerWeightsPtrs& layer,
+              AttentionActivationsPtrs& activations, MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
diff --git a/gemma/attention.h b/gemma/attention.h
index 60e6823..62e7132 100644
--- a/gemma/attention.h
+++ b/gemma/attention.h
@@ -51,6 +51,8 @@ namespace gcpp {
                       const LayerWeightsPtrs& layer,                         \
                       AttentionActivationsPtrs& activations, QBatch& qbatch, \
                       MatMulEnv& env, int flags);                            \
+  void SumHeads(const LayerWeightsPtrs& layer,                               \
+                AttentionActivationsPtrs& activations, MatMulEnv& env);      \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
 
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 8a9757b..473cbcc 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   float scale = old_d * std::exp(old_max - m);
   old_d = hn::ReduceSum(df, x) + scale;
   old_max = m;
-  float one_over_d = 1.0f / old_d;
-  scale *= one_over_d;
-  x = hn::Mul(x, hn::Set(df, one_over_d));
+  if (old_d > 0.0f) {
+    const float one_over_d = 1.0f / old_d;
+    scale *= one_over_d;
+    x = hn::Mul(x, hn::Set(df, one_over_d));
+  } else {
+    scale = 0.0f;
+    x = hn::Zero(df);
+  }
   return scale;
 }
 
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 5dd665d..1af520a 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   if (max_prompt_size > seq_len) {
-    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
-              max_prompt_size);
+    HWY_ABORT(
+        "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
+        "that.",
+        max_prompt_size, seq_len);
   }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
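
Note on the gemma/flash_attention.cc hunk (illustrative sketch, not part of the
commit): the guarded division is the normalization step of the online softmax
used by flash attention. The patch avoids 1.0f / old_d when the running
denominator old_d is zero, i.e. when every probability seen so far has
underflowed to zero (e.g. a fully masked row); without the guard, one_over_d
becomes inf and the subsequent multiplies turn the weights into NaN. The scalar
C++ sketch below reproduces the same update order outside of Highway.
UpdateRowSoftmax, the seed values and the driver are hypothetical names chosen
for illustration, not gemma.cpp's API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Online-softmax update for one attention row. Takes the raw scores of one
// new tile, updates the running max `old_max` and denominator `old_d`,
// writes the normalized weights for the tile into `weights`, and returns
// the factor by which the previously accumulated (already normalized)
// output must be rescaled.
float UpdateRowSoftmax(const std::vector<float>& tile_scores, float& old_max,
                       float& old_d, std::vector<float>& weights) {
  float m = old_max;
  for (float s : tile_scores) m = std::max(m, s);

  weights.resize(tile_scores.size());
  float tile_sum = 0.0f;
  for (size_t i = 0; i < tile_scores.size(); ++i) {
    weights[i] = std::exp(tile_scores[i] - m);
    tile_sum += weights[i];
  }

  // Same order as the patched SingleFlashAttentionRowVector:
  float scale = old_d * std::exp(old_max - m);  // rescale the old numerator
  old_d = tile_sum + scale;                     // new running denominator
  old_max = m;

  if (old_d > 0.0f) {
    const float one_over_d = 1.0f / old_d;
    scale *= one_over_d;
    for (float& w : weights) w *= one_over_d;
  } else {
    // All probabilities underflowed: without this branch, 1.0f / old_d is
    // inf and the 0 * inf products below would produce NaN weights.
    scale = 0.0f;
    for (float& w : weights) w = 0.0f;
  }
  return scale;
}

int main() {
  // Hypothetical seed: running max of 0, no probability mass accumulated yet,
  // chosen only to exercise the zero-denominator branch.
  float old_max = 0.0f, old_d = 0.0f;
  std::vector<float> w;
  const float scale = UpdateRowSoftmax({-1e30f, -2e30f}, old_max, old_d, w);
  std::printf("scale=%g w[0]=%g w[1]=%g\n", scale, w[0], w[1]);  // all zero
  return 0;
}

A caller would combine the results as acc = scale * acc + sum_i weights[i] *
v_i, so with the guard a fully masked row simply contributes an all-zero
accumulator instead of propagating NaN into the attention output.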
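
Note on the gemma/attention.cc and gemma/attention.h hunks (illustrative
sketch, not part of the commit): SumHeads loses `static HWY_INLINE` and gains
a declaration inside the per-target declaration macro in attention.h, which
gives it external linkage so other translation units built for the same SIMD
target can call it. Below is a minimal plain-C++ sketch of that pattern; the
macro name, the namespace names and the `int layer_idx` parameter are
placeholders, not the real Highway/gemma.cpp declarations.

#include <cstdio>

// One macro stamps the same declaration into every per-target namespace.
#define DECLARE_ATTENTION_API(NAMESPACE) \
  namespace NAMESPACE {                  \
  void SumHeads(int layer_idx);          \
  }  // namespace NAMESPACE

// Normally expanded once per SIMD target; two placeholder "targets" here.
DECLARE_ATTENTION_API(TargetAVX2)
DECLARE_ATTENTION_API(TargetScalar)

// The definition lives in a single .cc (attention.cc in the real code);
// because the declaration is in the shared header, another .cc for the same
// target can now call it instead of duplicating the head reduction.
namespace TargetScalar {
void SumHeads(int layer_idx) {
  std::printf("SumHeads(layer %d)\n", layer_idx);
}
}  // namespace TargetScalar

int main() { TargetScalar::SumHeads(0); }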