mirror of https://github.com/google/gemma.cpp.git
parent 1014ae9e2a
commit 14a9ecf21d
@@ -43,6 +43,7 @@
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/flash_attention.h"
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"

HWY_BEFORE_NAMESPACE();
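
The "// After highway.h" marker refers to Highway's per-target compilation scheme: -inl.h headers depend on macros that highway.h defines (HWY_TARGET, HWY_NAMESPACE) and are re-parsed once per SIMD target, so they must come after highway.h. Below is a minimal sketch of that preamble under the Highway library's documented foreach_target pattern; "foo.cc" is a placeholder file name and the -inl.h header chosen here is arbitrary, so this is not the actual attention source file.

// Sketch of Highway's per-target compilation preamble ("foo.cc" is a placeholder).
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "foo.cc"  // foreach_target.h re-includes this file once per target
#include "hwy/foreach_target.h"      // must precede highway.h
#include "hwy/highway.h"
// After highway.h
#include "ops/ops-inl.h"             // -inl.h headers rely on the per-target macros

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Per-target code (e.g. SumHeads) lives here.
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();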
@@ -319,25 +320,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
  });
}

// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                AttentionActivationsPtrs& activations,
                                MatMulEnv& env) {
  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
  const LayerConfig& layer_config = layer.layer_config;
  (void)layer_config;  // For HWY_DASSERT
  // att_weights and att_out are concatenated heads, each of length
  // layer_config.qkv_dim. Thus the [num_interleaved,
  // layer_config.model_dim] matmul output is the sum over heads. Compare
  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
  // encoded)
  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
              layer_config.qkv_dim != 0);
  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
             activations.att_sums);
}

void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                    const LayerWeightsPtrs& layer,
                    AttentionActivationsPtrs& activations, QBatch& qbatch,
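
The comment in SumHeads notes that att_out stores the heads concatenated along one axis, so a single [num_tokens, heads * qkv_dim] x [heads * qkv_dim, model_dim] matmul already performs the sum over heads of the einsum 'BTNH,NHD->BTD'. The following standalone sketch checks that equivalence numerically; it uses plain C++ loops and hypothetical toy dimensions, not gemma.cpp's CallMatMul API.

// Verifies: flattening the (head, head_dim) axes turns the per-head einsum
// 'BTNH,NHD->BTD' into one ordinary matmul over the flattened axis.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int T = 2, N = 3, H = 4, D = 5;   // tokens, heads, head_dim, model_dim
  std::vector<float> encoded(T * N * H);  // encoded[t][n][h], row-major
  std::vector<float> weights(N * H * D);  // weights[n][h][d], row-major
  for (int i = 0; i < T * N * H; ++i) encoded[i] = 0.01f * float(i % 17);
  for (int i = 0; i < N * H * D; ++i) weights[i] = 0.02f * float(i % 13);

  // Reference: explicit sum over heads, mirroring the einsum.
  std::vector<float> ref(T * D, 0.0f);
  for (int t = 0; t < T; ++t)
    for (int n = 0; n < N; ++n)
      for (int h = 0; h < H; ++h)
        for (int d = 0; d < D; ++d)
          ref[t * D + d] += encoded[(t * N + n) * H + h] * weights[(n * H + h) * D + d];

  // Single matmul over the flattened axis K = N * H: [T, K] x [K, D] -> [T, D].
  const int K = N * H;
  std::vector<float> out(T * D, 0.0f);
  for (int t = 0; t < T; ++t)
    for (int k = 0; k < K; ++k)
      for (int d = 0; d < D; ++d)
        out[t * D + d] += encoded[t * K + k] * weights[k * D + d];

  float max_diff = 0.0f;
  for (int i = 0; i < T * D; ++i)
    max_diff = std::max(max_diff, std::fabs(ref[i] - out[i]));
  std::printf("max |diff| = %g\n", max_diff);  // expect 0: identical products, identical order
  return 0;
}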
@@ -183,6 +183,25 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
  CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
}

// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                AttentionActivationsPtrs& activations,
                                MatMulEnv& env) {
  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
  const LayerConfig& layer_config = layer.layer_config;
  (void)layer_config;  // For HWY_DASSERT
  // att_weights and att_out are concatenated heads, each of length
  // layer_config.qkv_dim. Thus the [num_interleaved,
  // layer_config.model_dim] matmul output is the sum over heads. Compare
  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
  // encoded)
  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
              layer_config.qkv_dim != 0);
  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
             activations.att_sums);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
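
The `(void)layer_config;  // For HWY_DASSERT` line is needed because HWY_DASSERT is a debug-only assertion: in release builds it expands to nothing, which would leave layer_config unused and trigger -Wunused-variable. A minimal illustration of the same pattern, assuming the standard assert macro and a hypothetical Config struct in place of gemma.cpp's HWY_DASSERT and LayerConfig:

// Sketch: silencing unused-variable warnings for debug-only assertions.
#include <cassert>

struct Config { int model_dim = 0; };  // hypothetical stand-in for LayerConfig

void CheckConfig(const Config& config) {
  const int model_dim = config.model_dim;
  (void)model_dim;         // keeps release (NDEBUG) builds warning-free
  assert(model_dim != 0);  // compiled out entirely when NDEBUG is defined
}

int main() {
  Config config;
  config.model_dim = 2048;
  CheckConfig(config);
  return 0;
}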