Increased parallelism for RMSNormAndPositionalEncoding

PiperOrigin-RevId: 813738994
This commit is contained in:
Ray Smith 2025-10-01 07:10:40 -07:00 committed by Copybara-Service
parent 2f6cbde8ff
commit 6098a022b3
1 changed file with 19 additions and 18 deletions

View File

@@ -87,31 +87,32 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   static const auto zone =
       ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
   const float query_scale = activations.query_scale;
+  const hwy::Divisor div_qbatch(qbatch.Size());
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
-      for (size_t h = 0; h < layer.layer_config.heads; ++h) {
-        const size_t tq_idx = qbatch.Size() * task + qi;
-        // Find the token position in the query and calculate
-        // the range of cache positions to attend to.
-        const size_t pos = qbatch.Pos(qi) + task;
-        float* HWY_RESTRICT q_row =
-            q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
-        // Apply rope and scaling to Q.
-        if (layer.query_norm_scale.HasPtr()) {
-          CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
-            RMSNormInplace(weights_t->PackedScale1(), q_row,
-                           layer.layer_config.qkv_dim, ctx.profiler, worker);
-          });
-        }
-        PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
-                             worker, pos, query_scale);
-      }
+    size_t qi = div_qbatch.Remainder(task);
+    size_t batch_idx = div_qbatch.Divide(task);
+    for (size_t h = 0; h < layer.layer_config.heads; ++h) {
+      const size_t tq_idx = qbatch.Size() * batch_idx + qi;
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      float* HWY_RESTRICT q_row =
+          q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
+      // Apply rope and scaling to Q.
+      if (layer.query_norm_scale.HasPtr()) {
+        CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+          RMSNormInplace(weights_t->PackedScale1(), q_row,
+                         layer.layer_config.qkv_dim, ctx.profiler, worker);
+        });
+      }
+      PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
+                           worker, pos, query_scale);
     }
   };
   {
     // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens, ctx.pools, func);
+    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
   }
 }