mirror of https://github.com/google/gemma.cpp.git
parent
6d3e2b6f73
commit
d2090fddf3
|
|
@ -265,9 +265,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
layer.qkv_einsum_w2.Rows()));
|
layer.qkv_einsum_w2.Rows()));
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
++interleaved_idx) {
|
++interleaved_idx) {
|
||||||
|
// Index into qbatch, within [0, qbatch.Size()]
|
||||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
// Index along token sequence, within [0, num_tokens)
|
||||||
const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
|
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
|
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
|
||||||
// --seq_len must be large enough to avoid wraparound.
|
// --seq_len must be large enough to avoid wraparound.
|
||||||
HWY_DASSERT(cache_pos < activations.SeqLen());
|
HWY_DASSERT(cache_pos < activations.SeqLen());
|
||||||
|
|
||||||
|
|
@ -288,8 +290,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
const size_t head = task % kv_heads;
|
const size_t head = task % kv_heads;
|
||||||
const size_t interleaved_idx = task / kv_heads;
|
const size_t interleaved_idx = task / kv_heads;
|
||||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
const size_t token_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
const size_t cache_pos = qbatch.Pos(qi) + batch_idx;
|
const size_t cache_pos = qbatch.Pos(qi) + token_idx;
|
||||||
// --seq_len must be large enough to avoid wraparound.
|
// --seq_len must be large enough to avoid wraparound.
|
||||||
HWY_DASSERT(cache_pos < activations.SeqLen());
|
HWY_DASSERT(cache_pos < activations.SeqLen());
|
||||||
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue