commit 1a12c4d1a6
copybara-service[bot] 2025-12-09 09:55:42 +00:00, committed by GitHub
5 changed files with 17 additions and 8 deletions

View File

@@ -25,6 +25,7 @@
 #include "gemma/configs.h"     // ModelConfig
 #include "gemma/gemma_args.h"  // AttentionImpl
+#include "gemma/kv_cache.h"
 #include "ops/ops.h"      // CreateInvTimescale
 #include "util/basics.h"  // BF16
 #include "util/mat.h"     // MatStorageT

View File

@@ -321,9 +321,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
+void SumHeads(const LayerWeightsPtrs& layer,
+              AttentionActivationsPtrs& activations, MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
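
Note: the hunk above drops `static HWY_INLINE`, giving `SumHeads` external linkage so it can be called from other per-target translation units. For intuition, here is a minimal scalar sketch of the reduction the comment describes; shapes and names are assumptions for illustration, and the real code performs this as a MatMul through `env`:

#include <cstddef>
#include <vector>

// Scalar reference: every (head, head_dim) element of att_out contributes,
// via the output projection w, to each model_dim element of layer_out.
void SumHeadsScalar(const std::vector<float>& att_out,  // [tokens][heads * qkv_dim]
                    const std::vector<float>& w,        // [model_dim][heads * qkv_dim]
                    std::vector<float>& layer_out,      // [tokens][model_dim]
                    size_t tokens, size_t heads, size_t qkv_dim,
                    size_t model_dim) {
  const size_t in_dim = heads * qkv_dim;
  for (size_t t = 0; t < tokens; ++t) {
    for (size_t o = 0; o < model_dim; ++o) {
      float sum = 0.0f;
      for (size_t i = 0; i < in_dim; ++i) {  // Sums over heads and head_dim.
        sum += w[o * in_dim + i] * att_out[t * in_dim + i];
      }
      layer_out[t * model_dim + o] = sum;
    }
  }
}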

View File

@@ -51,6 +51,8 @@ namespace gcpp {
       const LayerWeightsPtrs& layer,                                    \
       AttentionActivationsPtrs& activations, QBatch& qbatch,            \
       MatMulEnv& env, int flags);                                       \
+  void SumHeads(const LayerWeightsPtrs& layer,                          \
+                AttentionActivationsPtrs& activations, MatMulEnv& env); \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */           \
   }  // namespace NAMESPACE
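
Note: this hunk extends a declaration macro that is stamped once per SIMD target namespace, so the newly non-static `SumHeads` is visible to per-target callers. A rough, hypothetical sketch of the pattern (the macro and namespace names here are invented for illustration; the real per-target namespaces come from Highway's foreach_target machinery):

// Forward declarations stand in for the real types from the gemma headers.
struct LayerWeightsPtrs;
struct AttentionActivationsPtrs;
struct MatMulEnv;

// Hypothetical macro; the repo's macro also declares the other attention
// entry points shown in the hunk above.
#define GEMMA_DECL_PER_TARGET(NAMESPACE)                                \
  namespace NAMESPACE {                                                 \
  void SumHeads(const LayerWeightsPtrs& layer,                          \
                AttentionActivationsPtrs& activations, MatMulEnv& env); \
  } /* namespace NAMESPACE */

// One instantiation per compiled SIMD target (names illustrative):
GEMMA_DECL_PER_TARGET(N_SSE4)
GEMMA_DECL_PER_TARGET(N_AVX2)
GEMMA_DECL_PER_TARGET(N_AVX3)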

View File

@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   float scale = old_d * std::exp(old_max - m);
   old_d = hn::ReduceSum(df, x) + scale;
   old_max = m;
-  float one_over_d = 1.0f / old_d;
-  scale *= one_over_d;
-  x = hn::Mul(x, hn::Set(df, one_over_d));
+  if (old_d > 0.0f) {
+    const float one_over_d = 1.0f / old_d;
+    scale *= one_over_d;
+    x = hn::Mul(x, hn::Set(df, one_over_d));
+  } else {
+    scale = 0.0f;
+    x = hn::Zero(df);
+  }
   return scale;
 }
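
Note: the guard above avoids dividing by a zero running denominator in the online-softmax update (e.g. when every term in the row is masked out or underflows to zero), which would otherwise produce NaNs. A simplified scalar analogue of the update, assuming `x` holds raw scores on entry (names are illustrative, not the repo's API):

#include <algorithm>
#include <cmath>
#include <vector>

// Returns the factor by which previously accumulated outputs must be
// rescaled, and normalizes x (the current row's terms) in place.
float OnlineSoftmaxStep(std::vector<float>& x, float& old_max, float& old_d) {
  float m = old_max;
  for (float v : x) m = std::max(m, v);  // Updated running max.
  float sum = 0.0f;
  for (float& v : x) {
    v = std::exp(v - m);  // exp(score - max), each <= 1.
    sum += v;
  }
  float scale = old_d * std::exp(old_max - m);  // Rescaled old denominator.
  old_d = sum + scale;                          // New running denominator.
  old_max = m;
  if (old_d > 0.0f) {
    const float one_over_d = 1.0f / old_d;
    scale *= one_over_d;
    for (float& v : x) v *= one_over_d;  // Normalize current terms.
  } else {
    scale = 0.0f;                         // Fully-masked row: output zeros
    std::fill(x.begin(), x.end(), 0.0f);  // instead of NaNs.
  }
  return scale;
}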

View File

@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   if (max_prompt_size > seq_len) {
-    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
-              max_prompt_size);
+    HWY_ABORT(
+        "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
+        "that.",
+        max_prompt_size, seq_len);
   }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);