diff --git a/gemma/attention.cc b/gemma/attention.cc
index f404674..576c0b7 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -252,7 +252,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                    AttentionActivations& activations,
                                    const QBatch& qbatch, const int flags,
                                    MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.QKV");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+
   const hwy::Divisor div_qbatch(qbatch.Size());
   const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
@@ -325,7 +328,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                 AttentionActivations& activations,
                                 MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.SumHeads");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
   // att_weights and att_out are concatenated heads, each of length
@@ -358,8 +363,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
-    FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
-                   activations, qbatch, env.ctx);
+    // * 2 does not help on Turin.
+    FlashAttention(num_tokens,
+                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
+                   layer_idx, layer, activations, qbatch, env.ctx);
   }
   SumHeads(layer, activations, env);
 }
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index ddf2bcc..33ad725 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -73,8 +73,9 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
+    // Better than kFlat.
+    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
@@ -111,8 +112,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
+    // kHierarchical is not worth the extra sync overhead because the tasks are
+    // very lightweight.
+    ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
@@ -722,7 +725,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   };

   {
-    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
   }
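Note (not part of the patch): the change leans on two recurring patterns, a profiler zone registered once via AddZone and entered per call with PROFILER_ZONE3, and a ParallelFor whose ParallelismStrategy (kFlat vs. kHierarchical) is picked per call site based on task weight. A minimal sketch of both follows; DoSomethingPerRow, the lambda signature, and the include set are assumptions for illustration, while the AddZone/PROFILER_ZONE3/ParallelFor calls mirror those visible in the hunks above.

    // Sketch only, under stated assumptions. Assumes the gemma.cpp internal
    // headers declaring MatMulEnv, ParallelFor and ParallelismStrategy are
    // available; hwy/profiler.h provides the profiler macros used in the diff.
    #include "hwy/profiler.h"

    // Hypothetical helper mirroring the ComputeQKV/SumHeads pattern: register
    // the zone once (static local), then enter it on every call.
    void DoSomethingPerRow(size_t num_rows, MatMulEnv& env) {
      static const auto zone = env.ctx.profiler.AddZone(
          "Gen.Example.DoSomethingPerRow", hwy::ProfilerFlags::kInclusive);
      PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);

      // Task body; the (task, worker) parameter list is assumed.
      const auto func = [&](uint64_t task, size_t worker) {
        // ... per-row work ...
      };
      // Lightweight tasks: kFlat avoids kHierarchical's extra sync overhead,
      // as in the RMSNormAndPositionalEncoding hunk above.
      ParallelFor(ParallelismStrategy::kFlat, num_rows, env.ctx,
                  /*cluster_idx=*/0, func);
    }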