1.14x batch decode speedup: parallelize RMSNorm ops

Activation ops were over-parallelized; use a single pool instead.
Also improve profiler zone annotations and pass through worker args
(for tracking concurrency), which are now non-optional.

PiperOrigin-RevId: 788790976
Jan Wassenberg, 2025-07-30 00:54:55 -07:00 (committed by Copybara-Service)
parent ac0d751d20
commit d1638587f0
12 changed files with 142 additions and 108 deletions
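
The pattern this commit applies, as a minimal standalone sketch (not the gemma.cpp implementation; SmallParallelForSketch and RMSNormRow are illustrative names, and the 1e-6 epsilon and (1 + weight) scaling are assumptions of the sketch): batched row ops such as RMSNorm become one lightweight task per row on a single flat pool, with the worker index threaded through so per-op profiler zones attribute time to the correct thread.

#include <cmath>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical stand-in for the SmallParallelFor added below: runs
// func(task, worker) for task in [0, num_tasks) on num_workers threads
// drawn from a single flat pool (here, plain std::threads).
void SmallParallelForSketch(size_t num_tasks, size_t num_workers,
                            const std::function<void(size_t, size_t)>& func) {
  std::vector<std::thread> threads;
  threads.reserve(num_workers);
  for (size_t worker = 0; worker < num_workers; ++worker) {
    threads.emplace_back([worker, num_tasks, num_workers, &func] {
      // Static partition: worker w handles tasks w, w + num_workers, ...
      for (size_t task = worker; task < num_tasks; task += num_workers) {
        func(task, worker);
      }
    });
  }
  for (std::thread& t : threads) t.join();
}

// Scalar RMSNorm over one row; worker is unused here but is where the real
// code opens PROFILER_ZONE2(worker, ...).
void RMSNormRow(const float* x, const float* weight, float* out, size_t size,
                size_t /*worker*/) {
  double sum_sq = 0.0;
  for (size_t i = 0; i < size; ++i) sum_sq += x[i] * x[i];
  const float mul = 1.0f / std::sqrt(static_cast<float>(sum_sq / size) + 1e-6f);
  for (size_t i = 0; i < size; ++i) out[i] = (1.0f + weight[i]) * mul * x[i];
}

int main() {
  const size_t rows = 8, dim = 16;
  std::vector<float> x(rows * dim, 1.0f), w(dim, 0.0f), out(rows * dim);
  // One lightweight task per row, as in the new RMSNormBatched.
  SmallParallelForSketch(rows, /*num_workers=*/4,
                         [&](size_t row, size_t worker) {
                           RMSNormRow(&x[row * dim], w.data(), &out[row * dim],
                                      dim, worker);
                         });
  return 0;
}

A single fork-join over one pool is the point: when per-row work is tiny, the barrier cost of nested or multiple pools exceeds the compute, which is what the commit message means by "over-parallelized".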

View File

@@ -474,6 +474,7 @@ cc_library(
         ":matmul",
         "//io",
         "@highway//:hwy",
+        "@highway//:profiler",
     ],
 )

View File

@@ -85,7 +85,7 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) {
-  Softmax(logits, vocab_size);
+  Softmax(logits, vocab_size, /*worker=*/0);
 }
 }  // namespace HWY_NAMESPACE

View File

@@ -29,7 +29,6 @@
 #include "util/allocator.h"  // Allocator
 #include "util/basics.h"     // BF16
 #include "util/mat.h"        // MatStorageT
-#include "hwy/profiler.h"
 
 namespace gcpp {
@@ -182,8 +181,8 @@ struct Activations {
     // Note that BindC on any MatMul output considerably slows down Prefill.
   }
 
+  // Negligible CPU time.
   void SetBatchSize(size_t batch_size) {
-    PROFILER_ZONE("SetBatchSize");
     x.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);

View File

@@ -150,7 +150,7 @@ void SingleDotSoftmaxWeightedSum(
   // SoftMax with optional SoftCap yields "probabilities" in att.
   const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
   MaybeLogitsSoftCap(att_cap, att, att_len, worker);
-  Softmax(att, att_len, /*temperature=*/1.0f, worker);
+  Softmax(att, att_len, worker, /*temperature=*/1.0f);
   WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out,
                worker);
@@ -168,7 +168,6 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
                            AttentionActivations& activations, QBatch& qbatch,
                            NestedPools& pools) {
-  PROFILER_ZONE("Gen.Attention.DotSoftmax.misc");
   static const uint32_t HWY_MAYBE_UNUSED zone_id_par =
       PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par");
@@ -227,8 +226,13 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
         layer, activations, att, att_out, worker);
   };
 
-  ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools,
-              /*pkg_idx=*/0, func);
+  {
+    PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
+    const size_t pkg_idx = 0;
+    // Full parallelism is helpful, SmallParallelFor is insufficient.
+    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                pools, pkg_idx, func);
+  }
 }
 
 // Different functions use different naming conventions for the number of

View File

@@ -66,12 +66,13 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1,
 template <class Mat>
 void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) {
   using T = typename Mat::T;
-  ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
-              [&](uint64_t task, size_t worker) {
-                // Cast to correct type so type deduction works.
-                Activation(activation, c1.Row(task),
-                           static_cast<const T*>(nullptr), c1.Cols(), worker);
-              });
+  const size_t pkg_idx = 0;
+  SmallParallelFor(
+      c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
+        // Cast to correct type so type deduction works.
+        Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
+                   c1.Cols(), worker);
+      });
 }
 
 template <class Mat>
@@ -79,18 +80,19 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1,
                                     const Mat* c2, NestedPools& pools) {
   using T = typename Mat::T;
   HWY_DASSERT(c1.SameShape(*c2));
+  const size_t pkg_idx = 0;
   if (c2 && c2->HasPtr()) {
-    ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
-                [&](uint64_t task, size_t worker) {
-                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                             worker);
-                });
+    SmallParallelFor(c1.Rows(), pools, pkg_idx,
+                     [&](uint64_t task, size_t worker) {
+                       Activation(activation, c1.Row(task), c2->Row(task),
+                                  c1.Cols(), worker);
+                     });
   } else {  // No multiplier
-    ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0,
-                [&](uint64_t task, size_t worker) {
-                  Activation(activation, c1.Row(task),
-                             static_cast<const T*>(nullptr), c1.Cols(), worker);
-                });
+    SmallParallelFor(
+        c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) {
+          Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
+                     c1.Cols(), worker);
+        });
   }
 }
@@ -98,17 +100,17 @@ template <typename T2, class LayerWeights>
 HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
                                      MatPtrT<float>& HWY_RESTRICT x,
                                      const LayerWeights& layer,
-                                     bool is_attention) {
+                                     bool is_attention, ThreadingContext& ctx) {
   // ResidualType::Add
-  AddFromBatched(other, x);
+  AddFromBatched(other, x, ctx);
 }
 
 template <typename InOutT>
 void PostNorm(PostNormType post_norm, const MatPtr& weights,
-              MatPtrT<InOutT>& inout) {
+              MatPtrT<InOutT>& inout, ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   if (post_norm == PostNormType::Scale) {
-    RMSNormInplaceBatched(weights, inout);
+    RMSNormInplaceBatched(weights, inout, ctx);
   }
 }

View File

@@ -92,19 +92,19 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
-                 activations.attention.pre_att_rms_out);
+                 activations.attention.pre_att_rms_out, env.ctx);
 
   Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
             qbatch, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
-           activations.attention.att_sums);
+           activations.attention.att_sums, env.ctx);
 
   ResidualConnection(activations.attention.att_sums, activations.x, layer,
-                     /*is_attention=*/true);
+                     /*is_attention=*/true, env.ctx);
 
   RMSNormBatched(activations.x, layer.pre_ffw_norm_scale,
-                 activations.pre_ffw_rms_out);
+                 activations.pre_ffw_rms_out, env.ctx);
 
   if (layer_config.type == LayerAttentionType::kVit) {
     FFWVit(layer, activations, env);
@@ -113,10 +113,10 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
   }
 
   PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale,
-           activations.ffw_out);
+           activations.ffw_out, env.ctx);
 
   ResidualConnection(activations.ffw_out, activations.x, layer,
-                     /*is_attention=*/false);
+                     /*is_attention=*/false, env.ctx);
 }
 
 // Returns the scale value to use for the embedding (basically sqrt model_dim).
@@ -158,6 +158,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   const size_t model_dim = model_config.model_dim;
   const float emb_scaling = EmbeddingScaling(model_dim);
+  const size_t worker = 0;  // Not yet parallelized.
 
   HWY_DASSERT(token >= 0);
   HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
@@ -173,7 +174,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const hn::ScalableTag<float> df;
     DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker);
   });
 
   if (model_config.absolute_pe) {
@@ -302,6 +303,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     }
   }
 
+  // TODO: parallelize?
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
                  /*pos_in_prompt=*/0, config, weights, activations.x);
@@ -328,7 +330,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
-  PROFILER_ZONE("Gen.Prefill");
+  PROFILER_ZONE("Gen.PrefillQ");
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
@@ -400,7 +402,7 @@ static void DecodeStepT(const ModelConfig& config,
   Transformer(config, runtime_config, weights, activations, qbatch, env);
 
-  RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
+  RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
     runtime_config.activations_observer(
@@ -414,9 +416,10 @@ static void DecodeStepT(const ModelConfig& config,
                 /*add=*/nullptr, env, activations.logits);
   }
 
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  const size_t worker = 0;  // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
     float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size);
+    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker);
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
@@ -430,10 +433,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
+  const size_t worker = 0;  // TODO: parallelize
+
   // Fast path for top-1 with no accept_token.
   if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
-      PROFILER_ZONE("Gen.Sample Top1");
+      PROFILER_ZONE2(worker, "Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
     };
   }
@@ -444,7 +449,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
     PROFILER_ZONE("Gen.Sample general");
     return FusedSoftmaxAndSampleTopK(
         logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
-        runtime_config.temperature, runtime_config.accept_token);
+        runtime_config.temperature, runtime_config.accept_token, worker);
   };
 }

View File

@@ -32,6 +32,7 @@
 #include "util/mat.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"               // HWY_ABORT
+#include "hwy/profiler.h"
 
 namespace gcpp {
@@ -116,6 +117,7 @@ struct RuntimeConfig {
   // If non-null, `batch_stream_token` is called for each token in the batch,
   // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
+    PROFILER_ZONE("Gen.StreamToken");
     if (batch_stream_token) {
       return batch_stream_token(query_idx, pos, token, prob);
     }

View File

@@ -90,12 +90,12 @@ class VitAttention {
     ZeroInit(activations_.attention.att_out);
     for (size_t head = 0; head < heads; ++head) {
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
         const size_t token = task;
         float* HWY_RESTRICT q =
             activations_.attention.q.Row(token) + head * 3 * qkv_dim;
         // TODO: shift to MatMul with A.scale once MatMul is confirmed working
-        MulByConst(query_scale, q, qkv_dim);
+        MulByConst(query_scale, q, qkv_dim, worker);
         hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
       });
@@ -109,19 +109,19 @@ class VitAttention {
       // this produces C, a (num_tokens_, seq_len) matrix of dot products
       CallMatMul(Q, K, nullptr, env_, C);
 
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
         float* HWY_RESTRICT c = C.Row(task);
-        Softmax(c, C.Cols());
+        Softmax(c, C.Cols(), worker);
       });
 
-      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR {
         size_t token = task;
         float* HWY_RESTRICT att_out =
             activations_.attention.att_out.Row(token) + head * qkv_dim;
         for (size_t i = 0; i < seq_len; ++i) {
           float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                   head * 3 * qkv_dim + 2 * qkv_dim;
-          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
+          MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker);
         }
       });
     }
@@ -138,13 +138,13 @@ class VitAttention {
     // Compute Q.K, softmax, and weighted V.
     pool_.Run(0, layer_config_.heads * num_tokens_,
-              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+              [&](uint64_t task, size_t worker) HWY_ATTR {
                 const size_t head = task % layer_config_.heads;
                 const size_t token = task / layer_config_.heads;
                 // Compute Q.K scores, which are "logits" stored in head_att.
                 float* HWY_RESTRICT q =
                     activations_.attention.q.Row(token) + head * 3 * qkv_dim;
-                MulByConst(query_scale, q, qkv_dim);
+                MulByConst(query_scale, q, qkv_dim, worker);
                 float* HWY_RESTRICT head_att =
                     activations_.attention.att.Row(token) + head * seq_len;
                 for (size_t i = 0; i < seq_len; ++i) {
@@ -153,7 +153,7 @@ class VitAttention {
                   head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
                 }
                 // SoftMax yields "probabilities" in head_att.
-                Softmax(head_att, seq_len);
+                Softmax(head_att, seq_len, worker);
                 // Compute weighted sum of v into att_out.
                 float* HWY_RESTRICT att_out =
                     activations_.attention.att_out.Row(token) + head * qkv_dim;
@@ -161,7 +161,7 @@ class VitAttention {
                 for (size_t i = 0; i < seq_len; ++i) {
                   float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
                                           head * 3 * qkv_dim + 2 * qkv_dim;
-                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker);
                 }
               });
   }
@@ -259,7 +259,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
   VitAttention(num_tokens, layer_idx, activations, layer, env)();
 
   // x = out["+sa"] = x + y
-  AddFromBatched(activations.attention.att_sums, x);
+  AddFromBatched(activations.attention.att_sums, x, env.ctx);
 
   // y = nn.LayerNorm()(x)
   // y ~ pre_ffw_rms_out
@@ -271,7 +271,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
   FFWVit(layer, activations, env);
 
   // x = out["+mlp"] = x + y
-  AddFromBatched(activations.ffw_out, x);
+  AddFromBatched(activations.ffw_out, x, env.ctx);
 }
 
 // Gets the patches of the image and embeds them with the image embedding
@@ -303,7 +303,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // Add position embeddings.
   CallUpcastedActivation(&weights.vit_img_pos_embedding,
                          [&](const auto* weights_t) {
-                           AddFromBatched(*weights_t, activations.x);
+                           AddFromBatched(*weights_t, activations.x, env.ctx);
                          });
 }
@@ -334,7 +334,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
     // Apply soft embedding norm before input projection.
     CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
      RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
-                     vit_model_dim);
+                     vit_model_dim, /*worker=*/0);
     });
   }

View File

@@ -875,9 +875,6 @@ class MMPerPackage {
         inner_tasks_(config.InnerTasks()),
         out_(config.Out()),
         line_bytes_(args.env->ctx.allocator.LineBytes()) {
-    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
-    MMZone zone;
-    zone.MaybeEnter(pkg_idx, zone_id, args_);
     A_ = DecompressA(A);
   }
@@ -1119,8 +1116,14 @@ class MMPerPackage {
     const size_t NBF = hn::Lanes(dbf);
     static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
 
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA");
     const auto do_range = [&](const IndexRange& range_M,
-                              const IndexRange& range_K) HWY_ATTR {
+                              const IndexRange& range_K,
+                              size_t worker) HWY_ATTR {
+      MMZone zone;
+      zone.MaybeEnter(worker, zone_id, args_);
       const size_t col0 = range_K.begin();
       const size_t cols = range_K.Num();
       // Must be a vector multiple, or the last range before row padding,
@@ -1141,7 +1144,7 @@ class MMPerPackage {
     switch (par_a) {
       case MMParA::kNone:
-        do_range(all_M, all_K);
+        do_range(all_M, all_K, /*worker=*/0);
         break;
       case MMParA::kK1:
       case MMParA::kK2:
@@ -1154,15 +1157,15 @@ class MMPerPackage {
         args_.env->parallel.ForNP(
             all_K, multiple_K, inner_tasks, pkg_idx_,
-            [&](const IndexRange& range_K, size_t /*worker*/) {
-              do_range(all_M, range_K);
-            });
+            [&](const IndexRange& range_K, size_t worker) {
+              do_range(all_M, range_K, worker);
+            });
         break;
       }
       case MMParA::kM:
         args_.env->parallel.ForRangeMC(
-            all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) {
-              do_range(IndexRange(row_a, row_a + 1), all_K);
-            });
+            all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+              do_range(IndexRange(row_a, row_a + 1), all_K, worker);
+            });
         break;
     }
@@ -1190,12 +1193,9 @@ class MMPerPackage {
     // First call: generate candidates.
     if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4};
-      if (A.Rows() == 1) {
-        candidates.push_back(MMParA::kNone);
-      } else {
-        candidates.push_back(MMParA::kM);
-      }
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
       autotune.SetCandidates(candidates);
     }
@@ -1279,7 +1279,8 @@ struct MMImpl {
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
-    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul");
+    PROFILER_ZONE("MM.DoMatMul");
+    static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg");
 
     // Outermost loop: static NUMA-aware partition of B rows across packages.
     args.env->parallel.ForPkg(
@@ -1353,7 +1354,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     return &per_key;
   }
 
-  PROFILER_ZONE("Matmul.Autotune");
+  // From here, CPU time is negligible except DoMatMul.
 
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
@@ -1364,7 +1365,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     HWY_ASSERT(N <= MMStorage::kMaxN);
     HWY_ASSERT(N % kNR == 0);
-    // Negligible CPU time.
     tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
                                      kNR, per_key.ranges_np, env.print_config));
   }

View File

@@ -208,7 +208,7 @@ template <typename XT, typename WT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs,
     OT* HWY_RESTRICT out, const size_t size,
-    const size_t HWY_MAYBE_UNUSED worker = 0) {
+    const size_t HWY_MAYBE_UNUSED worker) {
   PROFILER_ZONE2(worker, "ops.RMSNorm");
   namespace hn = hwy::HWY_NAMESPACE;
@@ -240,7 +240,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
 template <typename WT, typename XT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
     const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.RMSNormInplace");
   namespace hn = hwy::HWY_NAMESPACE;
@@ -527,7 +527,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
 template <typename XT>
 static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
     const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.AddFrom");
   namespace hn = hwy::HWY_NAMESPACE;
@@ -570,29 +570,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) {
-      RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
-              out.Row(token_idx), activations.Cols());
-    }
+    const size_t pkg_idx = 0;
+    SmallParallelFor(activations.Rows(), ctx.pools, pkg_idx,
+                     [&](uint64_t token_idx, size_t worker) {
+                       RMSNorm(activations.Row(token_idx),
+                               weights_t->PackedScale1(), 0, out.Row(token_idx),
+                               activations.Cols(), worker);
+                     });
   });
 }
 
 template <typename XT>
-void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout) {
+void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
+                           ThreadingContext& ctx) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) {
-      RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
-                     inout.Cols());
-    }
+    const size_t pkg_idx = 0;
+    SmallParallelFor(inout.Rows(), ctx.pools, pkg_idx,
+                     [&](uint64_t token_idx, size_t worker) {
+                       RMSNormInplace(weights_t->PackedScale1(), 0,
+                                      inout.Row(token_idx), inout.Cols(),
+                                      worker);
+                     });
  });
 }
@@ -614,18 +621,20 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
 }
 
 template <typename XT>
-static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x,
-                                      MatPtrT<float>& out) {
+static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
+                                      ThreadingContext& ctx) {
   HWY_DASSERT(out.SameShape(x));
-  for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) {
-    AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols());
-  }
+  const size_t pkg_idx = 0;
+  SmallParallelFor(
+      out.Rows(), ctx.pools, pkg_idx, [&](uint64_t token_idx, size_t worker) {
+        AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), worker);
+      });
 }
 
 template <typename XT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
     const float c, XT* HWY_RESTRICT x, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConst");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -666,7 +675,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConstTo");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -708,7 +717,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
 template <typename XT, typename OT>
 HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
     const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
-    const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) {
+    const size_t size, const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.MulByConstAndAdd");
   namespace hn = hwy::HWY_NAMESPACE;
   const hn::ScalableTag<float> df;
@@ -754,8 +763,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
 // See below for a specialized version for top-1 sampling.
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
-                                 float temperature = 1.0f,
-                                 const HWY_MAYBE_UNUSED size_t worker = 0) {
+                                 const size_t worker,
+                                 float temperature = 1.0f) {
   PROFILER_ZONE2(worker, "ops.Softmax");
   HWY_DASSERT(size != 0);
@@ -797,7 +806,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
   const float sum_exp = Sum(d, x, size);
   // Double-precision reciprocal does not appear to affect the results.
   const float mul = 1.0f / sum_exp;
-  MulByConst(mul, x, size);
+  MulByConst(mul, x, size, worker);
 }
 
 // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max /
@@ -886,9 +895,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x,
   return TokenAndProb{.token = argmax.token, .prob = prob};
 }
 
-static HWY_NOINLINE void LogitsSoftCap(
-    const float cap, float* HWY_RESTRICT x, const size_t size,
-    const HWY_MAYBE_UNUSED size_t worker = 0) {
+static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
+                                       const size_t size,
+                                       const HWY_MAYBE_UNUSED size_t worker) {
   PROFILER_ZONE2(worker, "ops.LogitsSoftCap");
   namespace hn = hwy::HWY_NAMESPACE;
@@ -906,7 +915,7 @@ static HWY_NOINLINE void LogitsSoftCap(
 // Calls LogitsSoftCap if cap != 0.0f.
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
     const float cap, float* HWY_RESTRICT x, const size_t size,
-    const size_t worker = 0) {
+    const size_t worker) {
   if (cap != 0.0f) {
     LogitsSoftCap(cap, x, size, worker);
   }
@@ -991,7 +1000,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
 template <typename TAcceptToken>
 HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
     const float* HWY_RESTRICT logits, size_t k, size_t vocab_size,
-    std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
+    std::mt19937& gen, float temperature, TAcceptToken& accept_token,
+    size_t worker) {
   // Softmax and sample top-K is equivalent to taking the top-K logits and
   // sampling from the softmax of the top-K logits. The latter is faster as it
   // avoids computing the softmax of all logits.
@@ -1005,7 +1015,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
   }
 
   size_t mask = token_logits.size();
-  Softmax(topk_logits.data(), mask, temperature);
+  Softmax(topk_logits.data(), mask, worker, temperature);
   auto distribution = std::discrete_distribution<int>(
       std::begin(topk_logits), std::begin(topk_logits) + mask);
   int topk_sampled_index = distribution(gen);
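
A note on the Softmax signature change above: moving the now-required worker ahead of the defaulted temperature is what forces every call site to update. A hedged sketch of the hazard the old order would have invited (OldSoftmax, NewSoftmax, and Caller are illustrative names, not the actual gemma.cpp declarations):

#include <cstddef>

// Illustrative signatures mirroring the old and new parameter orders.
inline void OldSoftmax(float* /*x*/, size_t /*size*/,
                       float /*temperature*/ = 1.0f, size_t /*worker*/ = 0) {}
inline void NewSoftmax(float* /*x*/, size_t /*size*/, size_t /*worker*/,
                       float /*temperature*/ = 1.0f) {}

inline void Caller(float* logits, size_t size, size_t worker) {
  // OldSoftmax(logits, size, worker);  // Compiles, but worker converts to
  //                                    // float and binds to temperature.
  NewSoftmax(logits, size, worker);     // worker binds where intended.
}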

View File

@@ -166,7 +166,7 @@ struct TestAddFrom {
     }
 
     SimpleAddFrom(o, e, count);
-    AddFrom(o, x, count);
+    AddFrom(o, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                            __LINE__);
@@ -199,7 +199,7 @@ struct TestMulByConstAndAdd {
     T constant = Random<T>(rng);
     SimpleMulByConstAndAdd(constant, o, e, count);
-    MulByConstAndAdd(constant, o, x, count);
+    MulByConstAndAdd(constant, o, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                            __LINE__);
@@ -229,7 +229,7 @@ struct TestMulByConst {
     T constant = Random<T>(rng);
     SimpleMulByConst(constant, e, count);
-    MulByConst(constant, x, count);
+    MulByConst(constant, x, count, /*worker=*/0);
 
     hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
                            __LINE__);
@@ -259,7 +259,7 @@ struct TestSoftmax {
     }
 
     SimpleSoftmax(e, count);
-    Softmax(x, count);
+    Softmax(x, count, /*worker=*/0);
 
     T sum = 0.0f;
     for (size_t i = 0; i < count; ++i) {
@@ -454,7 +454,7 @@ void TestRMSNorm(hwy::RandomState& rng) {
   }
 
   ScalarRMSNorm(vec, weight, expected, kSize);
-  RMSNorm(vec, weight, 0, actual, kSize);
+  RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0);
 
   for (size_t i = 0; i < kSize; i++) {
     const float e = hwy::ConvertScalarTo<float>(expected[i]);
@@ -584,7 +584,7 @@ void TestSampleTopK() {
   std::vector<float> logits(kSize);
   // Create a vector going from -100 to -100+51=49 and take Softmax.
   std::iota(logits.begin(), logits.end(), -100.0f);
-  Softmax(logits.data(), kSize);
+  Softmax(logits.data(), kSize, /*worker=*/0);
   std::mt19937 gen;
   gen.seed(0x12345678);
   float temperature = 1.0f;
@@ -600,7 +600,7 @@ void TestSampleTopK() {
   EXPECT_EQ(sample, 50);  // Last even index.
   // Reset the logits to a positive, increasing sequence and take Softmax.
   std::iota(logits.begin(), logits.end(), 1.0f);
-  Softmax(logits.data(), kSize);
+  Softmax(logits.data(), kSize, /*worker=*/0);
   // Sample from the top 3, expect one of the top 3 even indices.
   for (int i = 0; i < 100; ++i) {
     sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,

View File

@@ -355,6 +355,17 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
   });
 }
 
+// As above, but for lightweight tasks. Uses only one pool.
+template <class Func>
+void SmallParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx,
+                      const Func& func) {
+  const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage();
+  pools.Pool(pkg_idx).Run(0, num_tasks, [&](uint64_t task, size_t thread) {
+    func(task, pkg_base + thread);
+  });
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
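
One detail of SmallParallelFor worth noting (an inference from the code above, not stated in the commit): offsetting the pool's thread index by pkg_base = pkg_idx * pools.MaxWorkersPerPackage() keeps worker indices globally unique across packages, so PROFILER_ZONE2(worker, ...) records from different packages cannot collide. For example, with MaxWorkersPerPackage() == 8, thread 3 of package 1 reports as worker 11.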