diff --git a/gemma/activations.h b/gemma/activations.h
index d766ef7..6f615cf 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -74,11 +74,11 @@ struct Activations {
             is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
             MatPadding::kPacked),
-        inv_timescale(
-            CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
-                               layer_config.post_qk == PostQKType::HalfRope)),
+        inv_timescale(CreateInvTimescale(
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
+            layer_config.post_qk == PostQKType::HalfRope)),
         inv_timescale_global(CreateInvTimescale(
-            env->ctx.allocator, layer_config.qkv_dim,
+            ThreadingContext::Get().allocator, layer_config.qkv_dim,
             layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
         env(env) {
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 009d5a5..29c898b 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -503,14 +503,29 @@ class GemmaAttention {
                  const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
-                       activations, layer_weights, div_seq_len, kv_caches) {}
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
   // Constructor with default initialization to 0 for queries_prefix_end.
   GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
                  Activations& activations,
                  const LayerWeightsPtrs* layer_weights,
                  const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
       : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
-                       layer_weights, div_seq_len, kv_caches) {}
+                       layer_weights, div_seq_len, kv_caches,
+                       activations.env->ctx) {}
+  // Constructor with an explicit ThreadingContext. This is needed for
+  // experimental code that invokes methods that do not use `activations.env`.
+  // Callers should not have to construct an `activations.env` just to pass in
+  // the threading context.
+  GemmaAttention(const QueriesPos& queries_pos,
+                 const QueriesPos& queries_prefix_end, size_t num_tokens,
+                 size_t layer, Activations& activations,
+                 const LayerWeightsPtrs* layer_weights,
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
+      : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
+                       activations, layer_weights, div_seq_len, kv_caches,
+                       ctx) {}

   // Full attention computation in three steps.
   HWY_INLINE void operator()() {
@@ -526,7 +541,8 @@ class GemmaAttention {
                  const QueriesPos* queries_prefix_end, size_t num_tokens,
                  size_t layer, Activations& activations,
                  const LayerWeightsPtrs* layer_weights,
-                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
+                 const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+                 ThreadingContext& ctx)
       : queries_pos_(queries_pos),
         num_queries_(queries_pos.size()),
         num_tokens_(num_tokens),
@@ -540,8 +556,8 @@ class GemmaAttention {
         layer_weights_(*layer_weights),
         div_seq_len_(div_seq_len),
         kv_caches_(kv_caches),
-        allocator_(activations.env->ctx.allocator),
-        pool_(activations.env->ctx.pools.Pool(0)) {
+        allocator_(ctx.allocator),
+        pool_(ctx.pools.Pool(0)) {
     HWY_DASSERT(num_queries_ <= kv_caches_.size());
     HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                   "query heads must be a multiple of key-value heads");
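
Minimal usage sketch for the new explicit-ThreadingContext constructor, as motivated by the comment above: experimental code can pass a ThreadingContext directly instead of constructing an `activations.env`. The surrounding objects (queries_pos, queries_prefix_end, activations, layer_weights, div_seq_len, kv_caches) are assumed to exist as in the regular decode path; kNumTokens and kLayer are hypothetical placeholders, not values from this change.

    // Sketch only: all names other than GemmaAttention and ThreadingContext
    // are assumed from the caller's decode setup, not defined here.
    ThreadingContext& ctx = ThreadingContext::Get();
    constexpr size_t kNumTokens = 1;  // hypothetical
    constexpr size_t kLayer = 0;      // hypothetical
    GemmaAttention attention(queries_pos, queries_prefix_end, kNumTokens,
                             kLayer, activations, &layer_weights, div_seq_len,
                             kv_caches, ctx);
    attention();  // runs the full attention computation (three steps)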