mirror of https://github.com/google/gemma.cpp.git
parent 1014ae9e2a
commit 14a9ecf21d
@@ -43,6 +43,7 @@
 // After highway.h
 #include "compression/compress-inl.h"
 #include "gemma/flash_attention.h"
+#include "gemma/gemma-inl.h"
 #include "ops/ops-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -319,25 +320,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   });
 }
 
-// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
-// head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
-  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
-  const LayerConfig& layer_config = layer.layer_config;
-  (void)layer_config;  // For HWY_DASSERT
-  // att_weights and att_out are concatenated heads, each of length
-  // layer_config.qkv_dim. Thus the [num_interleaved,
-  // layer_config.model_dim] matmul output is the sum over heads. Compare
-  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
-  // encoded)
-  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
-              layer_config.qkv_dim != 0);
-  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
-             activations.att_sums);
-}
-
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
@@ -183,6 +183,25 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
   CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
 }
 
+// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
+// head_dim (`qkv_dim`) into output (`layer_out`).
+static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
+                                AttentionActivationsPtrs& activations,
+                                MatMulEnv& env) {
+  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
+  const LayerConfig& layer_config = layer.layer_config;
+  (void)layer_config;  // For HWY_DASSERT
+  // att_weights and att_out are concatenated heads, each of length
+  // layer_config.qkv_dim. Thus the [num_interleaved,
+  // layer_config.model_dim] matmul output is the sum over heads. Compare
+  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
+  // encoded)
+  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
+              layer_config.qkv_dim != 0);
+  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
+             activations.att_sums);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
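
The comment carried along with `SumHeads` is the heart of the moved code: `att_out` stores all heads contiguously (each of length `layer_config.qkv_dim`) and `layer.att_weights` is laid out to match, so a single matmul over the concatenated `heads * qkv_dim` axis already performs the sum over heads that gemma/modules.py writes as `attn_vec_einsum('BTNH,NHD->BTD', encoded)`. Below is a standalone sketch of that equivalence for one token, using illustrative sizes and plain loops rather than the library's `CallMatMul`:

// Standalone sketch: summing per-head projections equals one matmul over the
// concatenated heads axis. Sizes and names here are illustrative only.
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kHeads = 2, kQkvDim = 3, kModelDim = 4;

  // encoded[n*kQkvDim + h]: output of head n (one token), like `att_out`.
  // w[(n*kQkvDim + h)*kModelDim + d]: per-head projection, like `att_weights`.
  std::vector<float> encoded(kHeads * kQkvDim), w(kHeads * kQkvDim * kModelDim);
  for (size_t i = 0; i < encoded.size(); ++i) encoded[i] = 0.1f * (i + 1);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01f * (i + 1);

  // Reference: 'BTNH,NHD->BTD' written out as an explicit sum over heads.
  std::array<float, kModelDim> einsum{};
  for (size_t n = 0; n < kHeads; ++n) {
    for (size_t h = 0; h < kQkvDim; ++h) {
      for (size_t d = 0; d < kModelDim; ++d) {
        einsum[d] +=
            encoded[n * kQkvDim + h] * w[(n * kQkvDim + h) * kModelDim + d];
      }
    }
  }

  // Concatenated form: one [heads*qkv_dim] x [heads*qkv_dim, model_dim]
  // matmul; the sum over heads falls out of the shared inner axis.
  std::array<float, kModelDim> matmul{};
  for (size_t nh = 0; nh < kHeads * kQkvDim; ++nh) {
    for (size_t d = 0; d < kModelDim; ++d) {
      matmul[d] += encoded[nh] * w[nh * kModelDim + d];
    }
  }

  // Both columns print identical values.
  for (size_t d = 0; d < kModelDim; ++d) {
    std::printf("d=%zu  einsum=%.6f  matmul=%.6f\n", d, einsum[d], matmul[d]);
  }
  return 0;
}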
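
Regarding the first hunk, the new include lands in the block marked `// After highway.h` because, under Highway's usual -inl.h convention, such headers hold target-dependent code and may only be included after hwy/highway.h has established the active SIMD target; the `HWY_BEFORE_NAMESPACE();` line and the `}  // namespace HWY_NAMESPACE` closer visible in the hunks bracket the per-target code of this translation unit. A rough skeleton of that layout, assuming the standard Highway pattern rather than quoting the actual file:

// Rough skeleton of the file layout the first hunk edits (standard Highway
// -inl.h convention; an illustration, not a quote of the source file).
#include "hwy/highway.h"
// After highway.h: -inl.h headers are target-dependent, so they must come
// after the target has been set up above.
#include "gemma/gemma-inl.h"

HWY_BEFORE_NAMESPACE();  // Enables target-specific codegen attributes.
namespace gcpp {
namespace HWY_NAMESPACE {  // Expands to a per-target namespace.

// Target-dependent functions (GemmaAttention, SumHeads, ...) live here.

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();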