mirror of https://github.com/google/gemma.cpp.git
Improve FlashAttention threading: kFlat for RMSNorm (hierarchical is excessive), profiler zone naming improvements.

PiperOrigin-RevId: 814144012
This commit is contained in:
parent 6098a022b3
commit fe5a39990e
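For context on the commit message, the difference between kFlat and kHierarchical parallelism can be sketched in plain C++. This is a toy illustration only, not gemma.cpp's ParallelFor/ParallelismStrategy implementation; FlatFor, HierarchicalFor, and main are made-up names. A flat loop hands task indices to a single pool of workers and joins once, whereas a hierarchical loop first splits the range across clusters and then across workers within each cluster, adding a second fork/join layer. For very lightweight per-token work such as RMSNorm, that extra synchronization can cost more than it saves, which is why this commit switches RMSNormAndPositionalEncoding to kFlat while TransposeQ keeps the hierarchical strategy.

// Toy sketch of flat vs. hierarchical parallel-for; illustrative only.
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Flat: one pool of workers, a single fork/join barrier.
void FlatFor(size_t num_tasks, size_t num_workers,
             const std::function<void(size_t)>& func) {
  std::vector<std::thread> workers;
  for (size_t w = 0; w < num_workers; ++w) {
    workers.emplace_back([&, w] {
      // Static partitioning: worker w handles w, w + num_workers, ...
      for (size_t i = w; i < num_tasks; i += num_workers) func(i);
    });
  }
  for (std::thread& t : workers) t.join();  // single join
}

// Hierarchical: split across clusters, then across workers per cluster.
// Each cluster adds its own fork/join on top of the outer one.
void HierarchicalFor(size_t num_tasks, size_t num_clusters,
                     size_t workers_per_cluster,
                     const std::function<void(size_t)>& func) {
  std::vector<std::thread> clusters;
  for (size_t c = 0; c < num_clusters; ++c) {
    clusters.emplace_back([&, c] {
      const size_t begin = c * num_tasks / num_clusters;
      const size_t end = (c + 1) * num_tasks / num_clusters;
      std::vector<std::thread> workers;
      for (size_t w = 0; w < workers_per_cluster; ++w) {
        workers.emplace_back([&, w] {
          for (size_t i = begin + w; i < end; i += workers_per_cluster) func(i);
        });
      }
      for (std::thread& t : workers) t.join();  // per-cluster join
    });
  }
  for (std::thread& t : clusters) t.join();  // outer join
}

int main() {
  std::vector<int> out(16, 0);
  const std::function<void(size_t)> func = [&out](size_t i) {
    out[i] = static_cast<int>(i);
  };
  FlatFor(out.size(), /*num_workers=*/4, func);
  HierarchicalFor(out.size(), /*num_clusters=*/2, /*workers_per_cluster=*/2,
                  func);
  return 0;
}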
@@ -252,7 +252,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   AttentionActivations& activations,
                                   const QBatch& qbatch, const int flags,
                                   MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.QKV");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+
   const hwy::Divisor div_qbatch(qbatch.Size());
   const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
@@ -325,7 +328,9 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
 static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                 AttentionActivations& activations,
                                 MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Attention.SumHeads");
+  static const auto zone = env.ctx.profiler.AddZone(
+      "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
   // att_weights and att_out are concatenated heads, each of length
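The two hunks above replace the string-only PROFILER_ZONE macro with a zone that is registered once via AddZone and then opened with PROFILER_ZONE3. A minimal sketch of that pattern, applied to a hypothetical function (FooStep and the "Gen.FooStep" name are placeholders; AddZone, hwy::ProfilerFlags::kInclusive, hwy::Profiler::Thread(), and PROFILER_ZONE3 are used as they appear in the diff above):

// Sketch only: the zone-registration pattern introduced by this commit,
// applied to a hypothetical function. `env` is a MatMulEnv as in the diff.
void FooStep(MatMulEnv& env) {
  // Register the named zone once (static), so repeated calls reuse it.
  static const auto zone = env.ctx.profiler.AddZone(
      "Gen.FooStep", hwy::ProfilerFlags::kInclusive);
  // Open the zone for the current thread for the rest of this scope.
  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
  // ... work to be attributed to "Gen.FooStep" ...
}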
@@ -358,8 +363,10 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
-    FlashAttention(num_tokens, /*target_parallelism=*/64, layer_idx, layer,
-                   activations, qbatch, env.ctx);
+    // * 2 does not help on Turin.
+    FlashAttention(num_tokens,
+                   /*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
+                   layer_idx, layer, activations, qbatch, env.ctx);
   }
   SumHeads(layer, activations, env);
 }
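The hunk above changes FlashAttention's target_parallelism from a hard-coded 64 to the actual worker count, env.ctx.pools.MaxWorkers() * 1 (the retained "* 1" and the new comment record that an oversubscription factor of 2 did not help on Turin). As a rough, hypothetical illustration of why a worker-derived target beats a constant, a splitter could size its per-task work like this (ChooseRowsPerTask is illustrative, not gemma.cpp code):

#include <algorithm>
#include <cstddef>

// Illustrative only: pick a per-task row count so that roughly
// `target_parallelism` tasks are produced. A constant target (e.g. 64)
// over- or under-splits whenever the machine has a different worker count.
size_t ChooseRowsPerTask(size_t total_rows, size_t target_parallelism) {
  target_parallelism = std::max<size_t>(size_t{1}, target_parallelism);
  const size_t per_task =
      (total_rows + target_parallelism - 1) / target_parallelism;
  return std::max<size_t>(size_t{1}, per_task);
}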
@@ -73,8 +73,9 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<float>& q_t,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(q_t.Rows(), ctx.pools, func);
+    // Better than kFlat.
+    ParallelFor(ParallelismStrategy::kHierarchical, q_t.Rows(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
@@ -111,8 +112,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
     }
   };
   {
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
+    // kHierarchical is not worth the extra sync overhead because the tasks are
+    // very lightweight.
+    ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+                /*cluster_idx=*/0, func);
   }
 }
@@ -722,7 +725,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
   };
 
   {
-    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    PROFILER_ZONE("Gen.FlashAttention.ForkJoin");
     // Full parallelism is helpful, SmallParallelFor is insufficient.
     HierarchicalParallelFor(num_thread_tasks, ctx.pools, func);
   }