diff --git a/gemma/activations.h b/gemma/activations.h
index a0627ae..37f663f 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -198,6 +198,12 @@ struct AttentionActivationsPtrs {
   }
 
   const ModelConfig& config;
+
+  // For the matrices below, the batch_size dimension is really qbatch.Size() *
+  // token_batch_size, but in all known uses, one of those is 1. Specifically,
+  // during PrefillTBatch, it is prompt length (up to some max batch size)
+  // and otherwise it's qbatch.Size().
+
   // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
   // Query matrix of size batch_size x (q_heads * qkv_dim).
diff --git a/gemma/attention.cc b/gemma/attention.cc
index c2c984a..5880b50 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -187,12 +187,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);
 
     const size_t qi = div_qbatch.Remainder(tq_idx);
-    const size_t batch_idx = div_qbatch.Divide(tq_idx);
+    const size_t token_idx = div_qbatch.Divide(tq_idx);
     auto& kv_cache = qbatch.KV(qi).kv_cache;
 
     // Find the token position in the query and calculate
     // the range of cache positions to attend to.
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    const size_t pos = qbatch.Pos(qi) + token_idx;
     const size_t start_pos = StartPos(pos, activations.config, layer_idx);
     size_t last_pos = pos;
     const size_t prefix_end = qbatch.PrefixEnd(qi);
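
For context, below is a minimal standalone sketch (not part of the patch) of the index decomposition that motivates renaming batch_idx to token_idx. It assumes div_qbatch is a divisor by qbatch.Size(), so a flattened row index tq_idx over the batch_size = qbatch.Size() * token_batch_size rows splits into the query index qi (Remainder) and the per-query token offset token_idx (Divide). The sizes and the dummy base position standing in for qbatch.Pos(qi) are illustrative only.

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes only: 4 queries in the qbatch, 3 tokens per query
  // (per the new comment, token_batch_size is 1 outside of prefill).
  const size_t qbatch_size = 4;
  const size_t token_batch_size = 3;

  for (size_t tq_idx = 0; tq_idx < qbatch_size * token_batch_size; ++tq_idx) {
    // Stand-ins for div_qbatch.Remainder(tq_idx) and div_qbatch.Divide(tq_idx),
    // assuming the divisor is qbatch.Size().
    const size_t qi = tq_idx % qbatch_size;
    const size_t token_idx = tq_idx / qbatch_size;
    // Stand-in for pos = qbatch.Pos(qi) + token_idx, with a dummy base
    // position of 100 * qi in place of qbatch.Pos(qi).
    const size_t pos = 100 * qi + token_idx;
    std::printf("tq_idx=%2zu -> qi=%zu token_idx=%zu pos=%zu\n", tq_idx, qi,
                token_idx, pos);
  }
  return 0;
}

Each row of the q matrix therefore corresponds to one (query, token) pair, which is why the renamed token_idx (rather than a generic batch_idx) is the offset added to the query's base position.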