Improve clarity of indices II

Sorry, didn't see this one before.

PiperOrigin-RevId: 840218378
Martin Stolle 2025-12-04 06:32:56 -08:00 committed by Copybara-Service
parent 9348048885
commit b510ba2ab2
2 changed files with 8 additions and 2 deletions


@@ -198,6 +198,12 @@ struct AttentionActivationsPtrs {
   }
   const ModelConfig& config;
+  // For the matrices below, the batch_size dimension is really qbatch.Size() *
+  // token_batch_size, but in all known uses, one of those is 1. Specifically,
+  // during PrefillTBatch, it is prompt length (up to some max batch size)
+  // and otherwise it's qbatch.Size().
   // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
   // Query matrix of size batch_size x (q_heads * qkv_dim).
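
To make the layout concrete, here is a minimal standalone sketch (not gemma.cpp code; the sizes are made up) of how a flattened row index over batch_size = qbatch.Size() * token_batch_size splits into a query index and a token index, matching the div_qbatch split in the second file below. The real code uses hwy::Divisor for fast division; plain / and % are used here for clarity.

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes; in all known uses one of these two factors is 1.
  const size_t qbatch_size = 4;       // number of queries in the batch
  const size_t token_batch_size = 3;  // tokens per query (1 outside prefill)
  const size_t batch_size = qbatch_size * token_batch_size;  // rows in q

  // Decompose each flattened row index the same way div_qbatch does:
  // Remainder -> query index, Divide -> token index within that query.
  for (size_t tq_idx = 0; tq_idx < batch_size; ++tq_idx) {
    const size_t qi = tq_idx % qbatch_size;         // which query
    const size_t token_idx = tq_idx / qbatch_size;  // which token of the query
    printf("row %zu -> query %zu, token %zu\n", tq_idx, qi, token_idx);
  }
  return 0;
}

With token_batch_size == 1 (decode), the row index is just the query index; during PrefillTBatch, it is effectively the token offset within the prompt, consistent with the comment added above.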


@@ -187,12 +187,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
   const size_t qi = div_qbatch.Remainder(tq_idx);
-  const size_t batch_idx = div_qbatch.Divide(tq_idx);
+  const size_t token_idx = div_qbatch.Divide(tq_idx);
   auto& kv_cache = qbatch.KV(qi).kv_cache;
   // Find the token position in the query and calculate
   // the range of cache positions to attend to.
-  const size_t pos = qbatch.Pos(qi) + batch_idx;
+  const size_t pos = qbatch.Pos(qi) + token_idx;
   const size_t start_pos = StartPos(pos, activations.config, layer_idx);
   size_t last_pos = pos;
   const size_t prefix_end = qbatch.PrefixEnd(qi);
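
The index arithmetic in this hunk can be illustrated with a small self-contained sketch. Note that StartPosSketch is a hypothetical stand-in for gemma.cpp's StartPos (whose real logic derives the window from the ModelConfig and layer_idx), and all names and values here are illustrative, not the library's API.

#include <cstddef>
#include <cstdio>

// Hypothetical sliding-window variant: attend only to the most recent
// window_size positions. A global-attention layer would return 0.
static size_t StartPosSketch(size_t pos, size_t window_size) {
  return pos >= window_size ? pos - window_size + 1 : 0;
}

int main() {
  const size_t query_start = 100;  // qbatch.Pos(qi): first position of query qi
  const size_t token_idx = 5;      // offset of this token within the batch
  const size_t pos = query_start + token_idx;  // absolute token position

  const size_t window_size = 64;   // hypothetical per-layer attention window
  const size_t start_pos = StartPosSketch(pos, window_size);
  const size_t last_pos = pos;     // causal attention: up to the current token

  // DotSoftmaxWeightedSum then attends over KV-cache positions
  // [start_pos, last_pos] for this token.
  printf("pos %zu attends to [%zu, %zu]\n", pos, start_pos, last_pos);
  return 0;
}

Renaming batch_idx to token_idx makes the second operand of qbatch.Pos(qi) + token_idx read as what it is: the token's offset within the query, not an index into the query batch.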