From 684a0444e9bb6ddba53c361d87def055235ca387 Mon Sep 17 00:00:00 2001
From: Ray Smith
Date: Thu, 2 Oct 2025 08:14:37 -0700
Subject: [PATCH] Reduced parallelism for TransposeQ, making each thread read
 and write within its own cache lines

PiperOrigin-RevId: 814241032
---
 gemma/flash_attention.cc | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 77a4480..548c1aa 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -61,22 +61,30 @@ static constexpr size_t kNFx8HTileSize = 8;
 
 static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
                        const size_t qbatch_size, ThreadingContext& ctx) {
   static const auto zone = ctx.profiler.AddZone("Gen.Attention.TransposeQ");
+  // Group floats by the number of floats in a cache line.
+  const size_t kNF = ctx.cache_info.LineBytes() / sizeof(float);
   const size_t num_heads = q.Cols() / q_t.Rows();
   const size_t batch_size = q.Rows() / qbatch_size;
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    float* HWY_RESTRICT qt_row = q_t.Row(task);
-    for (size_t qi = 0; qi < qbatch_size; ++qi)
-      for (size_t h = 0; h < num_heads; ++h) {
-        for (size_t b = 0; b < batch_size; ++b) {
-          qt_row[(qi * num_heads + h) * batch_size + b] =
-              q.Row(b * qbatch_size + qi)[h * q_t.Rows() + task];
+    for (size_t lane = 0; lane < kNF; ++lane) {
+      size_t q_row = task * kNF + lane;
+      if (q_row >= q_t.Rows()) break;
+      float* HWY_RESTRICT qt_row = q_t.Row(q_row);
+      for (size_t qi = 0; qi < qbatch_size; ++qi) {
+        for (size_t h = 0; h < num_heads; ++h) {
+          for (size_t b = 0; b < batch_size; ++b) {
+            qt_row[(qi * num_heads + h) * batch_size + b] =
+                q.Row(b * qbatch_size + qi)[h * q_t.Rows() + q_row];
+          }
         }
       }
+    }
   };
   {  // Better than kFlat.
-    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+    size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
+    ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
                 /*cluster_idx=*/0, func);
   }
 }
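
-- 

Note (not part of the patch): the standalone sketch below illustrates the
decomposition the patch introduces. Each parallel task now owns kNF
consecutive output rows, i.e. the floats of whole cache lines, so no two
workers read or write the same line. The cache-line constant, the helper
name TransposeByCacheLine, and the serial task loop are assumptions standing
in for ctx.cache_info.LineBytes() and ParallelFor; this is an illustration,
not the gemma.cpp API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed cache-line size; the real code queries ctx.cache_info.LineBytes().
constexpr size_t kLineBytes = 64;
// Number of floats per cache line; plays the role of kNF in the patch.
constexpr size_t kNF = kLineBytes / sizeof(float);

// Transposes an (rows x cols) row-major matrix into out (cols x rows).
// Each task owns kNF consecutive output rows -- whole cache lines -- so
// parallel workers would never touch the same line (no false sharing).
void TransposeByCacheLine(const std::vector<float>& in, size_t rows,
                          size_t cols, std::vector<float>& out) {
  const size_t num_tasks = (cols + kNF - 1) / kNF;  // like hwy::DivCeil
  // Serial stand-in for ParallelFor: each iteration is one task.
  for (size_t task = 0; task < num_tasks; ++task) {
    for (size_t lane = 0; lane < kNF; ++lane) {
      const size_t out_row = task * kNF + lane;
      if (out_row >= cols) break;  // the last task may be partial
      for (size_t r = 0; r < rows; ++r) {
        out[out_row * rows + r] = in[r * cols + out_row];
      }
    }
  }
}

int main() {
  const size_t rows = 3, cols = 20;
  std::vector<float> in(rows * cols), out(cols * rows);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
  TransposeByCacheLine(in, rows, cols, out);
  // Transpose invariant: out[c * rows + r] == in[r * cols + c].
  std::printf("out[2 * rows + 1] = %g (expect %g)\n", out[2 * rows + 1],
              in[1 * cols + 2]);
}

With 64-byte lines, kNF is 16: under the old one-row-per-task scheme, up to
16 workers each pulled one float from the same cache line of q, whereas one
task now consumes whole lines, at the cost of 16x fewer tasks -- the
"reduced parallelism" of the commit title.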