diff --git a/gemma/attention.cc b/gemma/attention.cc
index 78eb496..c2c984a 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -265,9 +265,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                              layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
+    // Index into qbatch, within [0, qbatch.Size())
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
-    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    // Index along the token sequence, within [0, num_tokens)
+    const size_t token_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + token_idx;
     // --seq_len must be large enough to avoid wraparound.
     HWY_DASSERT(cache_pos < activations.SeqLen());
@@ -288,8 +290,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
     const size_t head = task % kv_heads;
     const size_t interleaved_idx = task / kv_heads;
     const size_t qi = div_qbatch.Remainder(interleaved_idx);
-    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
-    const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
+    const size_t token_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t cache_pos = qbatch.Pos(qi) + token_idx;
     // --seq_len must be large enough to avoid wraparound.
     HWY_DASSERT(cache_pos < activations.SeqLen());
     auto& kv_cache = qbatch.KV(qi).kv_cache;
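Note (not part of the patch): a minimal standalone sketch of the index decomposition the rename documents, assuming div_qbatch behaves like integer division by the query-batch size. The names qbatch_size, num_tokens, and start_pos below are made-up stand-ins for qbatch.Size(), num_tokens, and qbatch.Pos(qi). Remainder() selects the query within the batch, Divide() selects the position along that query's token sequence, and adding the query's starting position gives the KV-cache slot.

// Hypothetical standalone illustration; not gemma.cpp code.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t qbatch_size = 3;           // stand-in for qbatch.Size()
  const size_t num_tokens = 4;            // tokens per query in this batch
  const size_t start_pos[3] = {0, 7, 2};  // stand-in for qbatch.Pos(qi)
  const size_t num_interleaved = num_tokens * qbatch_size;

  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
    const size_t qi = interleaved_idx % qbatch_size;         // in [0, qbatch_size)
    const size_t token_idx = interleaved_idx / qbatch_size;  // in [0, num_tokens)
    const size_t cache_pos = start_pos[qi] + token_idx;      // KV-cache slot
    std::printf("interleaved=%zu qi=%zu token_idx=%zu cache_pos=%zu\n",
                interleaved_idx, qi, token_idx, cache_pos);
  }
  return 0;
}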