Improve FlashAttention threading:

- derive target_parallelism from the number of pool workers instead of hard-coding 64,
- kFlat for RMSNorm (hierarchical is excessive for such lightweight tasks),
- profiler zone naming improvements.

PiperOrigin-RevId: 814144012
Jan Wassenberg 2025-10-02 02:36:29 -07:00 committed by Copybara-Service
parent 6098a022b3
commit fe5a39990e
2 changed files with 19 additions and 9 deletions
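
For orientation, below is a self-contained sketch of what "flat" versus "hierarchical" fork/join means here. It uses plain std::thread and invented names (Task, FlatFor, HierarchicalFor); it is not gcpp's ParallelFor, thread pools, or pinning. The real ParallelismStrategy::kFlat / kHierarchical calls appear in the hunks below.

#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

using Task = std::function<void(size_t task, size_t worker)>;

// "Flat": a single fork/join across all workers. Setup and sync are cheap,
// so it wins when each task is very lightweight (the RMSNorm case below).
void FlatFor(size_t num_tasks, size_t num_workers, const Task& func) {
  std::vector<std::thread> threads;
  for (size_t w = 0; w < num_workers; ++w) {
    threads.emplace_back([=, &func] {
      for (size_t t = w; t < num_tasks; t += num_workers) func(t, w);
    });
  }
  for (std::thread& thread : threads) thread.join();
}

// "Hierarchical": fork across clusters, then across workers within each
// cluster. The extra level of synchronization only pays off when tasks are
// heavier (TransposeQ below keeps this strategy).
void HierarchicalFor(size_t num_tasks, size_t num_clusters,
                     size_t workers_per_cluster, const Task& func) {
  std::vector<std::thread> clusters;
  for (size_t c = 0; c < num_clusters; ++c) {
    clusters.emplace_back([=, &func] {
      const size_t begin = num_tasks * c / num_clusters;
      const size_t end = num_tasks * (c + 1) / num_clusters;
      FlatFor(end - begin, workers_per_cluster, [&](size_t t, size_t w) {
        func(begin + t, c * workers_per_cluster + w);
      });
    });
  }
  for (std::thread& thread : clusters) thread.join();
}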


@@ -252,7 +252,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   AttentionActivations& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.QKV");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const hwy::Divisor div_qbatch(qbatch.Size());
   const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
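
The pattern in this hunk (and in the SumHeads hunk below) registers the zone once via a function-local static and then opens it by handle on every call, rather than passing a string to PROFILER_ZONE each time. A minimal standalone analogue of that idea follows; it assumes nothing about hwy::Profiler beyond the names visible in the diff, and ToyProfiler/ScopedZone are invented for illustration.

#include <atomic>
#include <chrono>
#include <cstdint>

// Toy stand-in for a profiler. AddZone is the one-time registration; the
// returned handle is what the per-call zone guard uses.
struct ToyProfiler {
  struct Zone {
    const char* name;
    std::atomic<int64_t> total_ns{0};
  };
  Zone* AddZone(const char* name) { return new Zone{name}; }  // never freed (toy)
};

// Rough analogue of what a scoped zone macro expands to: read the clock on
// entry and exit, attribute the elapsed time to the pre-registered zone.
class ScopedZone {
 public:
  explicit ScopedZone(ToyProfiler::Zone& zone)
      : zone_(zone), start_(std::chrono::steady_clock::now()) {}
  ~ScopedZone() {
    const auto elapsed = std::chrono::steady_clock::now() - start_;
    zone_.total_ns +=
        std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count();
  }

 private:
  ToyProfiler::Zone& zone_;
  std::chrono::steady_clock::time_point start_;
};

void ComputeSomething(ToyProfiler& profiler) {
  // Registered once per call site, as in the hunk above; the per-call cost is
  // just the two clock reads.
  static ToyProfiler::Zone* const zone =
      profiler.AddZone("Gen.Attention.ComputeQKV");
  ScopedZone scoped(*zone);
  // ... measured work ...
}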
@@ -325,7 +328,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                 AttentionActivations& activations,
                                 MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.SumHeads");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
   // att_weights and att_out are concatenated heads, each of length
@@ -358,8 +363,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
-    FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
-                   activations, qbatch, env.ctx);
+    // * 2 does not help on Turin.
+    FlashAttention(num_tokens,
+                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
+                   layer_idx, layer, activations, qbatch, env.ctx);
   }
   SumHeads(layer, activations, env);
 }
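
The hard-coded target of 64 above becomes a function of how many workers the pools actually have. A hedged sketch of that sizing decision; std::thread::hardware_concurrency() stands in for env.ctx.pools.MaxWorkers(), which is the call the diff actually uses.

#include <cstddef>
#include <thread>

// Stand-in for env.ctx.pools.MaxWorkers(); the real code asks the gcpp thread
// pools rather than the OS.
size_t MaxWorkers() {
  const unsigned hw = std::thread::hardware_concurrency();
  return hw == 0 ? size_t{1} : static_cast<size_t>(hw);
}

// One task per worker. Per the comment in the hunk, oversubscribing ("* 2")
// did not help on Turin, so the multiplier stays at 1.
size_t FlashAttentionTargetParallelism() { return MaxWorkers() * 1; }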


@@ -73,8 +73,9 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
+    // Better than kFlat.
+    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
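
For reference, the per-task work here: each task writes one row of the transposed matrix by gathering a strided column of the source, which is heavy enough that the hierarchical strategy is kept ("Better than kFlat"). A standalone scalar sketch, with std::vector standing in for MatPtrT<float> and a serial loop standing in for the ParallelFor call.

#include <cstddef>
#include <vector>

// q is rows x cols, row-major; q_t is cols x rows, row-major. Writing output
// row `out_row` reads a strided column of q, so each task touches `rows`
// cache lines, which is the relatively heavy work handed to ParallelFor above.
void TransposeRow(const std::vector<float>& q, size_t rows, size_t cols,
                  std::vector<float>& q_t, size_t out_row) {
  for (size_t r = 0; r < rows; ++r) {
    q_t[out_row * rows + r] = q[r * cols + out_row];
  }
}

void Transpose(const std::vector<float>& q, size_t rows, size_t cols,
               std::vector<float>& q_t) {
  // In the real code this loop body is the func passed to ParallelFor.
  for (size_t out_row = 0; out_row < cols; ++out_row) {
    TransposeRow(q, rows, cols, q_t, out_row);
  }
}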
@@ -111,8 +112,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
+    // kHierarchical is not worth the extra sync overhead because the tasks are
+    // very lightweight.
+    ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
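
Each task in this hunk normalizes a single token's activation vector, which is cheap relative to the synchronization a hierarchical fork/join adds, hence kFlat. A standalone scalar sketch of RMSNorm over one row; the (1.0f + w[i]) scale and the epsilon follow common Gemma-style RMSNorm and are assumptions here, not read from this diff.

#include <cmath>
#include <cstddef>

// One task: RMSNorm of a single length-`dim` vector. Scalar reference only;
// the real kernel is vectorized and fused with positional encoding.
void RMSNormRow(const float* x, const float* w, float* out, size_t dim) {
  constexpr float kEps = 1e-6f;  // assumption, not gcpp's exact constant
  float sum_sq = 0.0f;
  for (size_t i = 0; i < dim; ++i) sum_sq += x[i] * x[i];
  const float inv_rms =
      1.0f / std::sqrt(sum_sq / static_cast<float>(dim) + kEps);
  for (size_t i = 0; i < dim; ++i) {
    out[i] = x[i] * inv_rms * (1.0f + w[i]);  // Gemma-style (1 + weight) scale
  }
}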
@@ -722,7 +725,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   };
   {
-    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
   }