From 6098a022b3fca21b54065a70dd6e341134623860 Mon Sep 17 00:00:00 2001
From: Ray Smith
Date: Wed, 1 Oct 2025 07:10:40 -0700
Subject: [PATCH] Increased parallelism for RMSNormAndPositionalEncoding

PiperOrigin-RevId: 813738994
---
 gemma/flash_attention.cc | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index b93b58f..ddf2bcc 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -87,31 +87,32 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   static const auto zone =
       ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
   const float query_scale = activations.query_scale;
+  const hwy::Divisor div_qbatch(qbatch.Size());
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
-      for (size_t h = 0; h < layer.layer_config.heads; ++h) {
-        const size_t tq_idx = qbatch.Size() * task + qi;
-        // Find the token position in the query and calculate
-        // the range of cache positions to attend to.
-        const size_t pos = qbatch.Pos(qi) + task;
-        float* HWY_RESTRICT q_row =
-            q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
-        // Apply rope and scaling to Q.
-        if (layer.query_norm_scale.HasPtr()) {
-          CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
-            RMSNormInplace(weights_t->PackedScale1(), q_row,
-                           layer.layer_config.qkv_dim, ctx.profiler, worker);
-          });
-        }
-        PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
-                             worker, pos, query_scale);
+    size_t qi = div_qbatch.Remainder(task);
+    size_t batch_idx = div_qbatch.Divide(task);
+    for (size_t h = 0; h < layer.layer_config.heads; ++h) {
+      const size_t tq_idx = qbatch.Size() * batch_idx + qi;
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      float* HWY_RESTRICT q_row =
+          q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
+      // Apply rope and scaling to Q.
+      if (layer.query_norm_scale.HasPtr()) {
+        CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+          RMSNormInplace(weights_t->PackedScale1(), q_row,
+                         layer.layer_config.qkv_dim, ctx.profiler, worker);
+        });
       }
+      PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
+                           worker, pos, query_scale);
     }
   };
   {  // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens, ctx.pools, func);
+    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
   }
 }
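
Note (not part of the commit): the patch flattens the per-token and per-query loops into a single task index, so HierarchicalParallelFor can distribute num_tokens * qbatch.Size() tasks instead of num_tokens. The minimal standalone sketch below shows the index decomposition with plain division and modulo; the patch itself uses hwy::Divisor, presumably so the hot lambda avoids a hardware division. All names and values in the sketch are illustrative only.

// sketch_task_split.cc -- standalone illustration, not the library's code.
// One parallel task per (token, query) pair: task = batch_idx * qbatch_size + qi.
// The patch recovers the pair via div_qbatch.Divide/Remainder; plain integer
// division and modulo, shown here, compute the same mapping.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_tokens = 3;   // illustrative values only
  const size_t qbatch_size = 2;  // stands in for qbatch.Size()

  for (size_t task = 0; task < num_tokens * qbatch_size; ++task) {
    const size_t qi = task % qbatch_size;         // cf. div_qbatch.Remainder(task)
    const size_t batch_idx = task / qbatch_size;  // cf. div_qbatch.Divide(task)
    std::printf("task %zu -> token %zu, query %zu\n", task, batch_idx, qi);
  }
  return 0;
}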