mirror of https://github.com/google/gemma.cpp.git
Improve FlashAttention threading:
kFlat for RMSNorm (hierarchical is excessive), profiler zone naming improvements. PiperOrigin-RevId: 814144012
This commit is contained in:
parent
6098a022b3
commit
fe5a39990e
|
|
@ -252,7 +252,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
AttentionActivations& activations,
|
||||
const QBatch& qbatch, const int flags,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
|
@ -325,7 +328,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
|||
static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||
AttentionActivations& activations,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.SumHeads");
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
(void)layer_config; // For HWY_DASSERT
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
|
|
@ -358,8 +363,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
|||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||
env.ctx);
|
||||
} else {
|
||||
FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
|
||||
activations, qbatch, env.ctx);
|
||||
// * 2 does not help on Turin.
|
||||
FlashAttention(num_tokens,
|
||||
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
|
||||
layer_idx, layer, activations, qbatch, env.ctx);
|
||||
}
|
||||
SumHeads(layer, activations, env);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -73,8 +73,9 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
|
|||
}
|
||||
};
|
||||
{
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
|
||||
// Better than kFlat.
|
||||
ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
|
||||
/*cluster_idx=*/0, func);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -111,8 +112,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
|||
}
|
||||
};
|
||||
{
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
|
||||
// kHierarchical is not worth the extra sync overhead because the tasks are
|
||||
// very lightweight.
|
||||
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
|
||||
/*cluster_idx=*/0, func);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -722,7 +725,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
};
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
|
||||
PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue