Internal changes

PiperOrigin-RevId: 842194766
This commit is contained in:
Krzysztof Rymski 2025-12-09 05:41:31 -08:00 committed by Copybara-Service
parent 14a9ecf21d
commit 64d700cab5
3 changed files with 13 additions and 5 deletions

View File

@@ -25,6 +25,7 @@
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT

View File

@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
if (old_d > 0.0f) {
const float one_over_d = 1.0f / old_d;
scale *= one_over_d;
x = hn::Mul(x, hn::Set(df, one_over_d));
} else {
scale = 0.0f;
x = hn::Zero(df);
}
return scale;
}

View File

@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
if (max_prompt_size > seq_len) {
HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
max_prompt_size);
HWY_ABORT(
"max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
"that.",
max_prompt_size, seq_len);
}
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);