Improve clarity of indices

PiperOrigin-RevId: 839805634
This commit is contained in:
Martin Stolle 2025-12-03 10:10:47 -08:00 committed by Copybara-Service
parent 6d3e2b6f73
commit d2090fddf3
1 changed files with 6 additions and 4 deletions

View File

@ -265,9 +265,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
layer.qkv_einsum_w2.Rows())); layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) { ++interleaved_idx) {
// Index into qbatch, within [0, qbatch.Size()]
const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx); // Index along token sequence, within [0, num_tokens)
const size_t cache_pos = qbatch.Pos(qi) + batch_idx; const size_t token_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
// --seq_len must be large enough to avoid wraparound. // --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen()); HWY_DASSERT(cache_pos < activations.SeqLen());
@ -288,8 +290,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const size_t head = task % kv_heads; const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads; const size_t interleaved_idx = task / kv_heads;
const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t token_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos = qbatch.Pos(qi) + batch_idx; const size_t cache_pos = qbatch.Pos(qi) + token_idx;
// --seq_len must be large enough to avoid wraparound. // --seq_len must be large enough to avoid wraparound.
HWY_DASSERT(cache_pos < activations.SeqLen()); HWY_DASSERT(cache_pos < activations.SeqLen());
auto& kv_cache = qbatch.KV(qi).kv_cache; auto& kv_cache = qbatch.KV(qi).kv_cache;