mirror of https://github.com/google/gemma.cpp.git
Adds a `GemmaAttention` constructor that takes an explicit `ThreadingContext`.
PiperOrigin-RevId: 757839682
This commit is contained in:
parent
45ad847a41
commit
ba21e3beb4
@@ -74,11 +74,11 @@ struct Activations {
             is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
             MatPadding::kPacked),

-        inv_timescale(
-            CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
-                               layer_config.post_qk == PostQKType::HalfRope)),
+        inv_timescale(CreateInvTimescale(
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
-            env->ctx.allocator, layer_config.qkv_dim,
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
             layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),

         env(env) {

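In the `Activations` hunk above, both inverse-timescale tables now obtain their allocator from the process-wide singleton rather than reaching through `env->ctx`. A minimal sketch of the access pattern, assuming only what the hunk itself shows (`ThreadingContext::Get()` and its `allocator` member); the local variable name is hypothetical:

    // Works with or without an activations env in scope; previously the
    // same allocator could only be reached as env->ctx.allocator.
    const auto& allocator = ThreadingContext::Get().allocator;
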
@@ -503,14 +503,29 @@ class GemmaAttention {
                  const LayerWeightsPtrs<T>* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
-                       activations, layer_weights, div_seq_len, kv_caches) {}
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
   // Constructor with default initialization to 0 for queries_prefix_end.
   GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
                  Activations& activations,
                  const LayerWeightsPtrs<T>* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
-                       layer_weights, div_seq_len, kv_caches) {}
+                       layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
+  // Constructor with an explicit ThreadingContext. This is needed for
+  // experimental code that invokes methods that do not use `activations.env`.
+  // Callers should not have to construct an `activations.env` just to pass in
+  // the threading context.
+  GemmaAttention(const QueriesPos& queries_pos,
+                 const QueriesPos& queries_prefix_end, size_t num_tokens,
+                 size_t layer, Activations& activations,
+                 const LayerWeightsPtrs<T>* layer_weights,
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
+      : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       ctx) {}

   // Full attention computation in three steps.
   HWY_INLINE void operator()() {

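The two convenience constructors above keep their public signatures and now forward `activations.env->ctx`, so existing call sites compile unchanged; the new overload is for callers that hold a `ThreadingContext` but no `activations.env`. A hedged usage sketch, assuming the argument variables are in scope, that `GemmaAttention<T>` is constructed directly, and that `ThreadingContext::Get()` returns a reference, as its use in the first hunk suggests; the trailing `()` invokes the `operator()` shown at the end of the hunk:

    // Existing call sites: the context comes implicitly from activations.env.
    GemmaAttention<T>(queries_pos, num_tokens, layer, activations,
                      layer_weights, div_seq_len, kv_caches)();

    // Experimental code: pass the context explicitly; no activations.env is
    // needed just to plumb threading through.
    GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
                      activations, layer_weights, div_seq_len, kv_caches,
                      ThreadingContext::Get())();
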
@@ -526,7 +541,8 @@ class GemmaAttention {
                  const QueriesPos* queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
                  const LayerWeightsPtrs<T>* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
       : queries_pos_(queries_pos),
         num_queries_(queries_pos.size()),
         num_tokens_(num_tokens),

@@ -540,8 +556,8 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        allocator_(activations.env->ctx.allocator),
-        pool_(activations.env->ctx.pools.Pool(0)) {
+        allocator_(ctx.allocator),
+        pool_(ctx.pools.Pool(0)) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                   "query heads must be a multiple of key-value heads");

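Net effect of the last two hunks: the delegated-to constructor takes the context as a parameter and draws both its allocator and its worker pool from it instead of from `activations.env->ctx`. A condensed sketch of the resulting initializer list (other members and arguments elided; the two bindings are taken directly from the hunks above):

    GemmaAttention(/* ...same arguments as above..., */ ThreadingContext& ctx)
        : /* ...other members..., */
          allocator_(ctx.allocator),   // was activations.env->ctx.allocator
          pool_(ctx.pools.Pool(0)) {   // was activations.env->ctx.pools.Pool(0)
      // DASSERTs unchanged.
    }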