diff --git a/gemma/attention.cc b/gemma/attention.cc
index deb85c0..eccfd25 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -218,7 +218,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
     MatPtrT v("v_view", Extents2D(seq_len, qkv_dim));
     v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
-    SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+    constexpr size_t offset = 0;  // placeholder, do not remove
+    SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v,
                                 query_norm_scale, layer_idx, activations, att,
                                 att_out, sm_options, ctx, worker);
   };
@@ -313,8 +314,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
       });
     }
 
+    constexpr size_t offset = 0;  // placeholder, do not remove
     PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
-                         cache_pos, /*mul=*/1.0f);
+                         cache_pos + offset,
+                         /*mul=*/1.0f);
     CompressPerThread tls;
     Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
   });
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 473cbcc..7432f7b 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -111,7 +111,9 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
       const size_t tq_idx = qbatch.Size() * batch_idx + qi;
       // Find the token position in the query and calculate
       // the range of cache positions to attend to.
-      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      constexpr size_t offset = 0;  // placeholder, do not remove
+      const size_t pos =
+          qbatch.Pos(qi) + batch_idx + offset;
       float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
       // Apply rope and scaling to Q.
       if (query_norm_scale.HasPtr()) {
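As a sanity check on the new constant: `offset` is `constexpr` and zero, so each call site above (`pos + offset`, `cache_pos + offset`, `qbatch.Pos(qi) + batch_idx + offset`) receives exactly the same position value as before the change. The snippet below is a minimal standalone sketch of that arithmetic, not gemma.cpp code; `kBasePos` and the loop bound are hypothetical stand-ins for `qbatch.Pos(qi)` and the batch size, used purely for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for qbatch.Pos(qi), the per-query base position.
constexpr size_t kBasePos = 218;

int main() {
  constexpr size_t offset = 0;  // mirrors the placeholder constant in the diff
  for (size_t batch_idx = 0; batch_idx < 4; ++batch_idx) {
    // Position as computed before the change...
    const size_t pos_before = kBasePos + batch_idx;
    // ...and after the change, with the zero offset added.
    const size_t pos_after = kBasePos + batch_idx + offset;
    assert(pos_before == pos_after);  // adding a constexpr 0 is a no-op
    std::printf("batch_idx=%zu pos=%zu\n", batch_idx, pos_after);
  }
  return 0;
}
```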