diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 9fd739a..1faf6a3 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -305,6 +305,9 @@ class GemmaAttention {
       }
     }
 
+    // Self-extension
+    const hwy::Divisor& div_grp_size{
+        static_cast<uint32_t>(layer_config_.grp_size)};
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, layer_config_.kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -319,7 +322,6 @@ class GemmaAttention {
                                          head * layer_config_.qkv_dim * 2;
                 KVCache& kv_cache = kv_caches_[query_idx];
 
-                const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) };
                 const size_t ngb_size = layer_config_.ngb_size;
                 const bool self_extend = layer_config_.self_extend;
 
@@ -328,7 +330,8 @@ class GemmaAttention {
                     activations_.q.Batch(interleaved_idx) + head * q_stride_ +
                     layer_config_.qkv_dim;
 
-                // When embedding position, we will use grouped key position
+                // In self-extend, when embedding the position,
+                // we use the grouped key position.
                 if (self_extend && pos > ngb_size) {
                   pos = div_grp_size.Divide(pos);
                 }
@@ -1484,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
                                        qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                      qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<T>(model, activations, runtime_config, qbatch_prompts,
                  qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
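
For context, the pos = div_grp_size.Divide(pos) branch above implements the grouped-position trick from Self-Extend: key positions inside the neighbor window ngb_size keep their exact values, while more distant positions are collapsed into groups of grp_size before the positional encoding is applied, so the embedded index stays within the trained context range. Below is a minimal standalone sketch of that remapping, not the gemma.cpp implementation itself; grp_size and ngb_size are hypothetical stand-ins for layer_config_.grp_size and layer_config_.ngb_size, and plain integer division stands in for hwy::Divisor::Divide().

// A minimal sketch of the Self-Extend position remapping added in the
// patch above (not the gemma.cpp implementation itself). Plain integer
// division stands in for hwy::Divisor::Divide(); grp_size and ngb_size
// mirror layer_config_.grp_size and layer_config_.ngb_size.
#include <cstddef>
#include <cstdio>

size_t GroupedPos(size_t pos, size_t grp_size, size_t ngb_size) {
  // Positions within the neighbor window keep their exact value;
  // positions beyond it are merged into groups of grp_size so the
  // positional embedding never sees an index far outside the trained range.
  return (pos > ngb_size) ? pos / grp_size : pos;
}

int main() {
  const size_t grp_size = 4;  // hypothetical values for illustration
  const size_t ngb_size = 8;
  const size_t sample[] = {2, 8, 9, 100};
  for (size_t pos : sample) {
    std::printf("pos=%3zu -> embedded pos=%3zu\n", pos,
                GroupedPos(pos, grp_size, ngb_size));
  }
  return 0;
}

Note also the design choice in the first hunk: hoisting div_grp_size out of the per-task lambda means the Divisor, which precomputes state for fast division by a runtime value, is constructed once per attention call rather than once per pool_.Run task.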