Adds a `GemmaAttention` constructor that takes an explicit `ThreadingContext`.

PiperOrigin-RevId: 757839682
author Biruk Mammo, 2025-05-12 11:16:29 -07:00 (committed by Copybara-Service)
parent 45ad847a41
commit ba21e3beb4
2 changed files with 25 additions and 9 deletions


@@ -74,11 +74,11 @@ struct Activations {
             is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
             MatPadding::kPacked),
-        inv_timescale(
-            CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
-                               layer_config.post_qk == PostQKType::HalfRope)),
+        inv_timescale(CreateInvTimescale(
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
-            env->ctx.allocator, layer_config.qkv_dim,
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
             layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
         env(env) {
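
For context, this first hunk switches `Activations` from the allocator reached through `activations.env` to the process-wide context. A minimal sketch of the pattern, assuming only what the diff itself shows (a `ThreadingContext::Get()` accessor whose result exposes an `allocator` member):

    // Before this commit, code needed an env to reach the allocator:
    //   const auto& allocator = env->ctx.allocator;
    // After, the global context provides it directly. const auto& is used
    // because the exact return type of Get() is not shown in this diff.
    const auto& allocator = ThreadingContext::Get().allocator;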


@@ -503,14 +503,29 @@ class GemmaAttention {
                  const LayerWeightsPtrs<T>* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
-                       activations, layer_weights, div_seq_len, kv_caches) {}
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
   // Constructor with default initialization to 0 for queries_prefix_end.
   GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
                  Activations& activations,
                  const LayerWeightsPtrs<T>* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
-                       layer_weights, div_seq_len, kv_caches) {}
+                       layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
+  // Constructor with an explicit ThreadingContext. This is needed for
+  // experimental code that invokes methods that do not use `activations.env`.
+  // Callers should not have to construct an `activations.env` just to pass in
+  // the threading context.
+  GemmaAttention(const QueriesPos& queries_pos,
+                 const QueriesPos& queries_prefix_end, size_t num_tokens,
+                 size_t layer, Activations& activations,
+                 const LayerWeightsPtrs<T>* layer_weights,
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
+      : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       ctx) {}
   // Full attention computation in three steps.
   HWY_INLINE void operator()() {
@@ -526,7 +541,8 @@ class GemmaAttention {
                  const QueriesPos* queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
                  const LayerWeightsPtrs<T>* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
       : queries_pos_(queries_pos),
         num_queries_(queries_pos.size()),
         num_tokens_(num_tokens),
@@ -540,8 +556,8 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        allocator_(activations.env->ctx.allocator),
-        pool_(activations.env->ctx.pools.Pool(0)) {
+        allocator_(ctx.allocator),
+        pool_(ctx.pools.Pool(0)) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                   "query heads must be a multiple of key-value heads");