mirror of https://github.com/google/gemma.cpp.git
Improve clarity of indices II
Sorry, didn't see this one before.

PiperOrigin-RevId: 840218378
parent 9348048885
commit b510ba2ab2

@@ -198,6 +198,12 @@ struct AttentionActivationsPtrs {
   }

   const ModelConfig& config;

+  // For the matrices below, the batch_size dimension is really qbatch.Size() *
+  // token_batch_size, but in all known uses, one of those is 1. Specifically,
+  // during PrefillTBatch, it is prompt length (up to some max batch size)
+  // and otherwise it's qbatch.Size().
+
   // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
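
The added comment documents the flattened leading dimension of these matrices. A minimal sketch of the shape arithmetic it describes, under the stated assumption that one of the two factors is always 1 (BatchRows and the example sizes are hypothetical, not from the repo):

#include <cstddef>

// The row count of q (and the other per-token matrices) is the product of
// the two batch factors; per the comment above, one of them is always 1.
constexpr size_t BatchRows(size_t qbatch_size, size_t token_batch_size) {
  return qbatch_size * token_batch_size;
}

// PrefillTBatch: a single query stream, rows = prompt length (here 512).
static_assert(BatchRows(/*qbatch_size=*/1, /*token_batch_size=*/512) == 512);
// Decode: several query streams, one token each, rows = qbatch.Size().
static_assert(BatchRows(/*qbatch_size=*/8, /*token_batch_size=*/1) == 8);
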
@@ -187,12 +187,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);

     const size_t qi = div_qbatch.Remainder(tq_idx);
-    const size_t batch_idx = div_qbatch.Divide(tq_idx);
+    const size_t token_idx = div_qbatch.Divide(tq_idx);
     auto& kv_cache = qbatch.KV(qi).kv_cache;

     // Find the token position in the query and calculate
     // the range of cache positions to attend to.
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    const size_t pos = qbatch.Pos(qi) + token_idx;
     const size_t start_pos = StartPos(pos, activations.config, layer_idx);
     size_t last_pos = pos;
     const size_t prefix_end = qbatch.PrefixEnd(qi);
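
The second hunk is a rename, batch_idx to token_idx, making explicit that div_qbatch decomposes the flattened loop index tq_idx into a token index and a query-stream index qi. A standalone sketch of the same arithmetic, with plain % and / standing in for the repo's cached-divisor helper and a toy stand-in for QBatch (ToyQBatch and its values are illustrative only):

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative stand-in for QBatch: each query stream tracks its own
// starting position in its KV cache.
struct ToyQBatch {
  std::vector<size_t> start;
  size_t Size() const { return start.size(); }
  size_t Pos(size_t qi) const { return start[qi]; }
};

int main() {
  ToyQBatch qbatch{{100, 0, 57}};     // three query streams
  const size_t token_batch_size = 2;  // tokens per stream in this batch

  // Flattened index, streams varying fastest:
  //   tq_idx = token_idx * qbatch.Size() + qi
  for (size_t tq_idx = 0; tq_idx < token_batch_size * qbatch.Size(); ++tq_idx) {
    const size_t qi = tq_idx % qbatch.Size();         // div_qbatch.Remainder
    const size_t token_idx = tq_idx / qbatch.Size();  // div_qbatch.Divide
    // Absolute position of this token in stream qi's KV cache, as in the
    // renamed line: pos = qbatch.Pos(qi) + token_idx.
    const size_t pos = qbatch.Pos(qi) + token_idx;
    std::printf("tq_idx=%zu -> qi=%zu token_idx=%zu pos=%zu\n",
                tq_idx, qi, token_idx, pos);
  }
  return 0;
}

For tq_idx = 0..5 this prints qi cycling through 0,1,2 while token_idx advances once per full cycle, matching the Remainder/Divide pair in the diff.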