diff --git a/gemma/activations.h b/gemma/activations.h
index 37f663f..021b2fd 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -25,6 +25,7 @@
 #include "gemma/configs.h"     // ModelConfig
 #include "gemma/gemma_args.h"  // AttentionImpl
+#include "gemma/kv_cache.h"
 #include "ops/ops.h"      // CreateInvTimescale
 #include "util/basics.h"  // BF16
 #include "util/mat.h"     // MatStorageT
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 8a9757b..473cbcc 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   float scale = old_d * std::exp(old_max - m);
   old_d = hn::ReduceSum(df, x) + scale;
   old_max = m;
-  float one_over_d = 1.0f / old_d;
-  scale *= one_over_d;
-  x = hn::Mul(x, hn::Set(df, one_over_d));
+  if (old_d > 0.0f) {
+    const float one_over_d = 1.0f / old_d;
+    scale *= one_over_d;
+    x = hn::Mul(x, hn::Set(df, one_over_d));
+  } else {
+    scale = 0.0f;
+    x = hn::Zero(df);
+  }
   return scale;
 }
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 5dd665d..1af520a 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   if (max_prompt_size > seq_len) {
-    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
-              max_prompt_size);
+    HWY_ABORT(
+        "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
+        "that.",
+        max_prompt_size, seq_len);
   }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
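
The flash_attention.cc hunk guards the online-softmax renormalization in SingleFlashAttentionRowVector against a non-positive running denominator: the old unconditional `1.0f / old_d` turns into Inf when `old_d` is zero, and that Inf/NaN then propagates into the accumulated attention output. Below is a minimal scalar sketch of the same update step, assuming `probs` already holds `exp(score - new_max)` for the current block; the function name and signature are illustrative only, not part of the repository.

#include <algorithm>
#include <cmath>
#include <vector>

// Returns the factor by which previously accumulated output rows must be
// rescaled; mutates `probs` into normalized probabilities in place.
float OnlineSoftmaxUpdate(std::vector<float>& probs, float new_max,
                          float& running_max, float& running_denom) {
  // Rescale the old denominator to the new running maximum, then fold in
  // the exponentials of the current block.
  const float rescale = running_denom * std::exp(running_max - new_max);
  float denom = rescale;
  for (float p : probs) denom += p;
  running_denom = denom;
  running_max = new_max;

  if (denom > 0.0f) {
    const float inv = 1.0f / denom;
    for (float& p : probs) p *= inv;  // normalized probabilities
    return rescale * inv;             // factor for the accumulated output
  }
  // Degenerate block: the denominator is zero, e.g. because every exponential
  // underflowed. Emit zeros instead of dividing by zero, mirroring the guard
  // added in the diff above.
  std::fill(probs.begin(), probs.end(), 0.0f);
  return 0.0f;
}

Returning 0.0f for both the rescale factor and the probabilities keeps the accumulated output at exactly zero for such rows, which is the same behavior the vector version gets from `scale = 0.0f` and `hn::Zero(df)`.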