mirror of https://github.com/google/gemma.cpp.git
Increased parallelism for RMSNormAndPositionalEncoding
PiperOrigin-RevId: 813738994
This commit is contained in:
parent
2f6cbde8ff
commit
6098a022b3
|
|
@ -87,14 +87,16 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
static const auto zone =
|
static const auto zone =
|
||||||
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
|
ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
|
||||||
const float query_scale = activations.query_scale;
|
const float query_scale = activations.query_scale;
|
||||||
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
|
||||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
size_t qi = div_qbatch.Remainder(task);
|
||||||
|
size_t batch_idx = div_qbatch.Divide(task);
|
||||||
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
for (size_t h = 0; h < layer.layer_config.heads; ++h) {
|
||||||
const size_t tq_idx = qbatch.Size() * task + qi;
|
const size_t tq_idx = qbatch.Size() * batch_idx + qi;
|
||||||
// Find the token position in the query and calculate
|
// Find the token position in the query and calculate
|
||||||
// the range of cache positions to attend to.
|
// the range of cache positions to attend to.
|
||||||
const size_t pos = qbatch.Pos(qi) + task;
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
float* HWY_RESTRICT q_row =
|
float* HWY_RESTRICT q_row =
|
||||||
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
|
q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
|
|
@ -107,11 +109,10 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
|
||||||
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
|
PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
|
||||||
worker, pos, query_scale);
|
worker, pos, query_scale);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||||
HierarchicalParallelFor(num_tokens, ctx.pools, func);
|
HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue