mirror of https://github.com/google/gemma.cpp.git
Increased parallelism for RMSNormAndPositionalEncoding
PiperOrigin-RevId: 813738994
parent 2f6cbde8ff
commit 6098a022b3
@@ -87,31 +87,32 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
   static const auto zone =
       ctx.profiler.AddZone("Gen.Attention.RMSNormAndPositionalEncoding");
   const float query_scale = activations.query_scale;
+  const hwy::Divisor div_qbatch(qbatch.Size());
   const auto func = [&](const size_t task, size_t worker) HWY_ATTR {
     PROFILER_ZONE3(ctx.profiler, worker, zone);
-    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
-      for (size_t h = 0; h < layer.layer_config.heads; ++h) {
-        const size_t tq_idx = qbatch.Size() * task + qi;
-        // Find the token position in the query and calculate
-        // the range of cache positions to attend to.
-        const size_t pos = qbatch.Pos(qi) + task;
-        float* HWY_RESTRICT q_row =
-            q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
-        // Apply rope and scaling to Q.
-        if (layer.query_norm_scale.HasPtr()) {
-          CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
-            RMSNormInplace(weights_t->PackedScale1(), q_row,
-                           layer.layer_config.qkv_dim, ctx.profiler, worker);
-          });
-        }
-        PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
-                             worker, pos, query_scale);
-      }
+    size_t qi = div_qbatch.Remainder(task);
+    size_t batch_idx = div_qbatch.Divide(task);
+    for (size_t h = 0; h < layer.layer_config.heads; ++h) {
+      const size_t tq_idx = qbatch.Size() * batch_idx + qi;
+      // Find the token position in the query and calculate
+      // the range of cache positions to attend to.
+      const size_t pos = qbatch.Pos(qi) + batch_idx;
+      float* HWY_RESTRICT q_row =
+          q.Row(tq_idx) + h * layer.layer_config.qkv_dim;
+      // Apply rope and scaling to Q.
+      if (layer.query_norm_scale.HasPtr()) {
+        CallUpcasted(&layer.query_norm_scale, [&](const auto* weights_t) {
+          RMSNormInplace(weights_t->PackedScale1(), q_row,
+                         layer.layer_config.qkv_dim, ctx.profiler, worker);
+        });
+      }
+      PositionalEncodingQK(q_row, layer_idx, layer, activations, ctx.profiler,
+                           worker, pos, query_scale);
     }
   };
   {
     // Full parallelism is helpful, SmallParallelFor is insufficient.
-    HierarchicalParallelFor(num_tokens, ctx.pools, func);
+    HierarchicalParallelFor(num_tokens * qbatch.Size(), ctx.pools, func);
   }
 }

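For context, the change flattens the former two-level iteration (tokens outer, queries in the batch inner) into a single task index, so the pool can schedule num_tokens * qbatch.Size() independent tasks instead of num_tokens. Below is a minimal, self-contained C++ sketch of that index-splitting pattern; it uses plain division and remainder where the commit uses hwy::Divisor (a precomputed multiplicative inverse that avoids hardware division inside the hot lambda). The sizes and names here are illustrative only, not taken from the repository.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_tokens = 4;   // hypothetical token count
  const size_t qbatch_size = 3;  // hypothetical queries per batch

  // One task per (token, query) pair. Each task recovers its pair from the
  // flattened index, mirroring div_qbatch.Remainder(task) / Divide(task).
  for (size_t task = 0; task < num_tokens * qbatch_size; ++task) {
    const size_t qi = task % qbatch_size;         // query index within batch
    const size_t batch_idx = task / qbatch_size;  // token index
    std::printf("task %zu -> token %zu, query %zu\n", task, batch_idx, qi);
  }
  return 0;
}

The serial loop stands in for HierarchicalParallelFor(num_tokens * qbatch.Size(), ...). Because each (token, query) task computes a distinct tq_idx and therefore writes a distinct q row, the tasks are independent and can be spread across workers without synchronization.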