Increased parallelism for RMSNormAndPositionalEncoding

PiperOrigin-RevId: 813738994
This commit is contained in:
Ray Smith 2025-10-01 07:10:40 -07:00 committed by Copybara-Service
parent 2f6cbde8ff
commit 6098a022b3
1 changed files with 19 additions and 18 deletions

View File

@ -87,31 +87,32 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
static const auto zone =
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
const float query_scale = activations.query_scale;
const hwy::Divisor div_qbatch(qbatch.Size());
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
PROFILER_ZONE3(ctx.profiler, worker, zone);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * task + qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + task;
float* HWY_RESTRICT q_row =
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
worker, pos, query_scale);
size_t qi = div_qbatch.Remainder(task);
size_t batch_idx = div_qbatch.Divide(task);
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT q_row =
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), q_row,
layer.layer_config.qkv_dim, ctx.profiler, worker);
});
}
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
worker, pos, query_scale);
}
};
{
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_tokens, ctx.pools, func);
HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
}
}