mirror of https://github.com/google/gemma.cpp.git
parent 5a6895c609
commit 61dedf73ed
@@ -25,6 +25,7 @@
 #include "gemma/configs.h"     // ModelConfig
 #include "gemma/gemma_args.h"  // AttentionImpl
+#include "gemma/kv_cache.h"
 #include "ops/ops.h"           // CreateInvTimescale
 #include "util/basics.h"       // BF16
 #include "util/mat.h"          // MatStorageT
@@ -321,9 +321,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 
 // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
 // head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
+void SumHeads(const LayerWeightsPtrs& layer,
+              AttentionActivationsPtrs& activations, MatMulEnv& env) {
   GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
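For context on the hunk above: per its comment, SumHeads reduces the per-head attention outputs (`att_out`, one `qkv_dim`-length slice per head) into `layer_out` by summing over both num_heads and head_dim. A minimal scalar sketch of that reduction follows; the explicit per-head output weight and the plain loops are illustrative stand-ins for gemma.cpp's MatMul path (MatMulEnv and friends), and every name in the sketch is hypothetical.

// Scalar sketch of the head-summation step, assuming each head's att_out
// slice is projected by a per-head output weight and accumulated into the
// layer output. Names are illustrative, not gemma.cpp's API.
#include <cstddef>
#include <vector>

void SumHeadsSketch(const std::vector<float>& att_out,  // [heads * qkv_dim]
                    const std::vector<float>& w_out,    // [heads * qkv_dim * model_dim]
                    size_t heads, size_t qkv_dim, size_t model_dim,
                    std::vector<float>& layer_out) {    // [model_dim]
  layer_out.assign(model_dim, 0.0f);
  for (size_t h = 0; h < heads; ++h) {
    for (size_t d = 0; d < qkv_dim; ++d) {
      const float a = att_out[h * qkv_dim + d];
      const float* row = &w_out[(h * qkv_dim + d) * model_dim];
      for (size_t m = 0; m < model_dim; ++m) {
        layer_out[m] += a * row[m];  // sum over num_heads and head_dim
      }
    }
  }
}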
@@ -51,6 +51,8 @@ namespace gcpp {
       const LayerWeightsPtrs& layer,                                      \
       AttentionActivationsPtrs& activations, QBatch& qbatch,              \
       MatMulEnv& env, int flags);                                         \
+  void SumHeads(const LayerWeightsPtrs& layer,                            \
+                AttentionActivationsPtrs& activations, MatMulEnv& env);   \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */             \
   }  // namespace NAMESPACE
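The two hunks above go together: SumHeads loses `static HWY_INLINE` in the source file and gains a declaration in the per-SIMD-target macro block in the header, so other translation units compiled for the same target can call it. Below is a minimal, self-contained sketch of that pattern with made-up names (EXAMPLE_DECL_SUM_HEADS, N_EXAMPLE, SumHeadsExample); in gemma.cpp the namespace is Highway's HWY_NAMESPACE and the files are compiled once per target, which the sketch does not attempt to reproduce.

// Self-contained illustration (hypothetical names) of declaring a per-target
// function in a header macro so a definition in one .cc file is callable from
// another .cc file compiled for the same target.
#include <cstdio>

// "Header": expands to a declaration inside the given target namespace.
#define EXAMPLE_DECL_SUM_HEADS(NAMESPACE)                   \
  namespace NAMESPACE {                                     \
  float SumHeadsExample(const float* per_head, int heads);  \
  } /* namespace NAMESPACE */

EXAMPLE_DECL_SUM_HEADS(N_EXAMPLE)

// "attention.cc": the definition, no longer static or inline, so it has
// external linkage and a single instance per target.
namespace N_EXAMPLE {
float SumHeadsExample(const float* per_head, int heads) {
  float sum = 0.0f;
  for (int h = 0; h < heads; ++h) sum += per_head[h];
  return sum;
}
}  // namespace N_EXAMPLE

// "caller.cc": any other translation unit that includes the header and is
// compiled for the same target can now call the function directly.
int main() {
  const float per_head[3] = {1.0f, 2.0f, 3.0f};
  std::printf("%.1f\n", N_EXAMPLE::SumHeadsExample(per_head, 3));
  return 0;
}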
@@ -425,9 +425,14 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
   float scale = old_d * std::exp(old_max - m);
   old_d = hn::ReduceSum(df, x) + scale;
   old_max = m;
-  float one_over_d = 1.0f / old_d;
-  scale *= one_over_d;
-  x = hn::Mul(x, hn::Set(df, one_over_d));
+  if (old_d > 0.0f) {
+    const float one_over_d = 1.0f / old_d;
+    scale *= one_over_d;
+    x = hn::Mul(x, hn::Set(df, one_over_d));
+  } else {
+    scale = 0.0f;
+    x = hn::Zero(df);
+  }
   return scale;
 }
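The hunk above guards the online-softmax normalization in SingleFlashAttentionRowVector against a zero running denominator: previously `one_over_d = 1.0f / old_d` was computed unconditionally, which is infinite when `old_d` is zero (nothing accumulated yet, or all probabilities underflowed) and then poisons the output with NaNs. The change returns a zero rescale factor and zeroed probabilities in that case. A scalar sketch of the same update follows; the function and parameter names are illustrative, and the SIMD vector `x` is modeled as a std::vector.

// Scalar sketch of the online-softmax update with the zero-denominator guard.
// `probs` plays the role of x (exp(score - new_max) for the current chunk);
// the return value is the factor by which previously accumulated outputs must
// be rescaled. Names are illustrative, not gemma.cpp's API.
#include <cmath>
#include <vector>

float OnlineSoftmaxUpdate(std::vector<float>& probs,  // exp(score - new_max)
                          float& old_max, float& old_d, float new_max) {
  float sum = 0.0f;
  for (float p : probs) sum += p;
  // Rescale the old denominator to the new max, then fold in this chunk.
  float scale = old_d * std::exp(old_max - new_max);
  old_d = sum + scale;
  old_max = new_max;
  if (old_d > 0.0f) {
    const float one_over_d = 1.0f / old_d;
    scale *= one_over_d;                     // correction for prior output
    for (float& p : probs) p *= one_over_d;  // normalized probabilities
  } else {
    // Nothing contributed so far: avoid 1/0 -> inf and downstream NaNs.
    scale = 0.0f;
    probs.assign(probs.size(), 0.0f);
  }
  return scale;
}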
@@ -519,8 +519,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
     HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   if (max_prompt_size > seq_len) {
-    HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.",
-              max_prompt_size);
+    HWY_ABORT(
+        "max_prompt_size = %zu, seq_len = %zu, increase --seq_len to at least "
+        "that.",
+        max_prompt_size, seq_len);
   }
   HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
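The last hunk only improves the abort message: it now reports both max_prompt_size and the configured seq_len, so the user can see by how much the sequence length falls short. As a hedged usage sketch (only the --seq_len flag is taken from the message itself; the rest of the command line is elided), the remedy when this abort fires is to rerun with a larger value:

# Hypothetical invocation; raise --seq_len to at least the reported max_prompt_size.
./gemma ... --seq_len 8192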