mirror of https://github.com/google/gemma.cpp.git
Improve clarity of indices II
Sorry, didn't see this one before.

PiperOrigin-RevId: 840218378
parent 9348048885
commit b510ba2ab2
@@ -198,6 +198,12 @@ struct AttentionActivationsPtrs {
   }

   const ModelConfig& config;
+
+  // For the matrices below, the batch_size dimension is really qbatch.Size() *
+  // token_batch_size, but in all known uses, one of those is 1. Specifically,
+  // during PrefillTBatch, it is prompt length (up to some max batch size)
+  // and otherwise it's qbatch.Size().
+
   // Query matrix of size batch_size x (q_heads * qkv_dim).
   MatPtrT<float> q;
   // Query matrix of size batch_size x (q_heads * qkv_dim).
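To make the added comment concrete, here is a minimal standalone sketch, not part of the commit, of the flattened batch convention it describes. The constants (kNumQueries, kTokenBatchSize, kQHeads, kQkvDim) and the row-major layout are illustrative assumptions only.

// Illustrative sketch: names and layout are assumptions, not gemma.cpp code.
#include <cstddef>
#include <vector>

int main() {
  const size_t kNumQueries = 4;      // qbatch.Size(); 1 during PrefillTBatch
  const size_t kTokenBatchSize = 1;  // prompt length during PrefillTBatch, else 1
  const size_t kQHeads = 8, kQkvDim = 256;

  // "batch_size" rows, one per (token, query) pair; in all known uses one of
  // the two factors is 1, as the new comment states.
  const size_t batch_size = kNumQueries * kTokenBatchSize;
  std::vector<float> q(batch_size * kQHeads * kQkvDim);

  // A single flattened index tq_idx enumerates every row of q.
  for (size_t tq_idx = 0; tq_idx < batch_size; ++tq_idx) {
    float* row = q.data() + tq_idx * kQHeads * kQkvDim;
    row[0] = 0.0f;  // each row holds q_heads * qkv_dim query values
  }
  return 0;
}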
@@ -187,12 +187,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     GCPP_ZONE(ctx, worker, Zones::kGenAttentionDotSoftmaxWeightedSumPar);

     const size_t qi = div_qbatch.Remainder(tq_idx);
-    const size_t batch_idx = div_qbatch.Divide(tq_idx);
+    const size_t token_idx = div_qbatch.Divide(tq_idx);
     auto& kv_cache = qbatch.KV(qi).kv_cache;

     // Find the token position in the query and calculate
     // the range of cache positions to attend to.
-    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    const size_t pos = qbatch.Pos(qi) + token_idx;
     const size_t start_pos = StartPos(pos, activations.config, layer_idx);
     size_t last_pos = pos;
     const size_t prefix_end = qbatch.PrefixEnd(qi);
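The rename makes it explicit that Divide() recovers a per-query token offset, not a batch index. A rough sketch of the index split, under the assumption that div_qbatch divides by qbatch.Size() (as the surrounding Remainder/Divide usage suggests); the concrete values are made up for illustration:

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed values; in the real code these come from qbatch and the caller.
  const size_t num_queries = 3;  // qbatch.Size(), the divisor behind div_qbatch
  const size_t num_tokens = 4;   // token_batch_size

  for (size_t tq_idx = 0; tq_idx < num_queries * num_tokens; ++tq_idx) {
    const size_t qi = tq_idx % num_queries;         // div_qbatch.Remainder(tq_idx)
    const size_t token_idx = tq_idx / num_queries;  // div_qbatch.Divide(tq_idx)
    // token_idx then offsets the query's base position, as in
    // pos = qbatch.Pos(qi) + token_idx in the hunk above.
    std::printf("tq_idx=%zu -> qi=%zu, token_idx=%zu\n", tq_idx, qi, token_idx);
  }
  return 0;
}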