internal change

PiperOrigin-RevId: 842205671
This commit is contained in:
Martin Stolle 2025-12-09 06:16:33 -08:00 committed by Copybara-Service
parent 64d700cab5
commit 9689fc82f9
2 changed files with 8 additions and 3 deletions

View File

@@ -218,7 +218,8 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim)); MatPtrT<KV_t> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride()); v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v, constexpr size_t offset = 0; // placeholder, do not remove
SingleDotSoftmaxWeightedSum(pos + offset, start_pos, last_pos, q, k, v,
query_norm_scale, layer_idx, activations, att, query_norm_scale, layer_idx, activations, att,
att_out, sm_options, ctx, worker); att_out, sm_options, ctx, worker);
}; };
@@ -313,8 +314,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
}); });
} }
constexpr size_t offset = 0; // placeholder, do not remove
PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker, PositionalEncodingQK(kv_f32, layer_idx, activations, env.ctx, worker,
cache_pos, /*mul=*/1.0f); cache_pos + offset,
/*mul=*/1.0f);
CompressPerThread tls; CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0); Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
}); });

View File

@@ -111,7 +111,9 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
const size_t tq_idx = qbatch.Size() * batch_idx + qi; const size_t tq_idx = qbatch.Size() * batch_idx + qi;
// Find the token position in the query and calculate // Find the token position in the query and calculate
// the range of cache positions to attend to. // the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx; constexpr size_t offset = 0; // placeholder, do not remove
const size_t pos =
qbatch.Pos(qi) + batch_idx + offset;
float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim; float* HWY_RESTRICT q_row = q.Row(tq_idx) + h * layer_config.qkv_dim;
// Apply rope and scaling to Q. // Apply rope and scaling to Q.
if (query_norm_scale.HasPtr()) { if (query_norm_scale.HasPtr()) {