From d1638587f02f547aa8ad76f36481996e05f3ae77 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 30 Jul 2025 00:54:55 -0700 Subject: [PATCH] 1.14x batch decode speedup: parallelize RMSNorm ops Activations was over-parallelized, use single pool instead. Also improve profiler zone annotations, pass through worker args (for tracking concurrency), now non-optional. PiperOrigin-RevId: 788790976 --- BUILD.bazel | 1 + evals/cross_entropy.cc | 2 +- gemma/activations.h | 3 +- gemma/attention.cc | 12 +++++--- gemma/gemma-inl.h | 42 +++++++++++++------------ gemma/gemma.cc | 29 +++++++++-------- gemma/gemma_args.h | 2 ++ gemma/vit.cc | 28 ++++++++--------- ops/matmul-inl.h | 36 +++++++++++----------- ops/ops-inl.h | 70 ++++++++++++++++++++++++------------------ ops/ops_test.cc | 14 ++++----- util/threading.h | 11 +++++++ 12 files changed, 142 insertions(+), 108 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 4421189..2628bc3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -474,6 +474,7 @@ cc_library( ":matmul", "//io", "@highway//:hwy", + "@highway//:profiler", ], ) diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 320967c..09c3a42 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -85,7 +85,7 @@ namespace gcpp { namespace HWY_NAMESPACE { void CallSoftmax(float* HWY_RESTRICT logits, size_t vocab_size) { - Softmax(logits, vocab_size); + Softmax(logits, vocab_size, /*worker=*/0); } } // namespace HWY_NAMESPACE diff --git a/gemma/activations.h b/gemma/activations.h index ccb1a59..b222bd9 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -29,7 +29,6 @@ #include "util/allocator.h" // Allocator #include "util/basics.h" // BF16 #include "util/mat.h" // MatStorageT -#include "hwy/profiler.h" namespace gcpp { @@ -182,8 +181,8 @@ struct Activations { // Note that BindC on any MatMul output considerably slows down Prefill. } + // Negligible CPU time. void SetBatchSize(size_t batch_size) { - PROFILER_ZONE("SetBatchSize"); x.OverrideRows(batch_size); logits.OverrideRows(batch_size); diff --git a/gemma/attention.cc b/gemma/attention.cc index 936db08..74ea77a 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -150,7 +150,7 @@ void SingleDotSoftmaxWeightedSum( // SoftMax with optional SoftCap yields "probabilities" in att. const size_t att_len = HWY_MIN(last_pos + 1, seq_len); MaybeLogitsSoftCap(att_cap, att, att_len, worker); - Softmax(att, att_len, /*temperature=*/1.0f, worker); + Softmax(att, att_len, worker, /*temperature=*/1.0f); WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out, worker); @@ -168,7 +168,6 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, AttentionActivations& activations, QBatch& qbatch, NestedPools& pools) { - PROFILER_ZONE("Gen.Attention.DotSoftmax.misc"); static const uint32_t HWY_MAYBE_UNUSED zone_id_par = PROFILER_ADD_ZONE("Gen.Attention.DotSoftmax.par"); @@ -227,8 +226,13 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, layer, activations, att, att_out, worker); }; - ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, pools, - /*pkg_idx=*/0, func); + { + PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin"); + const size_t pkg_idx = 0; + // Full parallelism is helpful, SmallParallelFor is insufficient. 
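+          // (ParallelFor spreads these per-head tasks over the package's
+          // nested pools; the single-pool SmallParallelFor added in
+          // util/threading.h forks/joins more cheaply but, as noted above,
+          // does not expose enough concurrency for this much work per task.)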
+ ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads, + pools, pkg_idx, func); + } } // Different functions use different naming conventions for the number of diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 7aa3318..e3f9a19 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -66,12 +66,13 @@ void Activation(ActivationType activation, T* HWY_RESTRICT c1, template void ActivationBatched(ActivationType activation, Mat& c1, NestedPools& pools) { using T = typename Mat::T; - ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, - [&](uint64_t task, size_t worker) { - // Cast to correct type so type deduction works. - Activation(activation, c1.Row(task), - static_cast(nullptr), c1.Cols(), worker); - }); + const size_t pkg_idx = 0; + SmallParallelFor( + c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) { + // Cast to correct type so type deduction works. + Activation(activation, c1.Row(task), static_cast(nullptr), + c1.Cols(), worker); + }); } template @@ -79,18 +80,19 @@ HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2, NestedPools& pools) { using T = typename Mat::T; HWY_DASSERT(c1.SameShape(*c2)); + const size_t pkg_idx = 0; if (c2 && c2->HasPtr()) { - ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, - [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), - worker); - }); + SmallParallelFor(c1.Rows(), pools, pkg_idx, + [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), c2->Row(task), + c1.Cols(), worker); + }); } else { // No multiplier - ParallelFor(c1.Rows(), pools, /*pkg_idx=*/0, - [&](uint64_t task, size_t worker) { - Activation(activation, c1.Row(task), - static_cast(nullptr), c1.Cols(), worker); - }); + SmallParallelFor( + c1.Rows(), pools, pkg_idx, [&](uint64_t task, size_t worker) { + Activation(activation, c1.Row(task), static_cast(nullptr), + c1.Cols(), worker); + }); } } @@ -98,17 +100,17 @@ template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, const LayerWeights& layer, - bool is_attention) { + bool is_attention, ThreadingContext& ctx) { // ResidualType::Add - AddFromBatched(other, x); + AddFromBatched(other, x, ctx); } template void PostNorm(PostNormType post_norm, const MatPtr& weights, - MatPtrT& inout) { + MatPtrT& inout, ThreadingContext& ctx) { HWY_DASSERT(weights.Rows() == 1); if (post_norm == PostNormType::Scale) { - RMSNormInplaceBatched(weights, inout); + RMSNormInplaceBatched(weights, inout, ctx); } } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 109ccad..66a8433 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -92,19 +92,19 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens, const LayerConfig& layer_config = layer.layer_config; RMSNormBatched(activations.x, layer.pre_attention_norm_scale, - activations.attention.pre_att_rms_out); + activations.attention.pre_att_rms_out, env.ctx); Attention(layer_config.type, num_tokens, layer_idx, layer, activations, qbatch, env); PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, - activations.attention.att_sums); + activations.attention.att_sums, env.ctx); ResidualConnection(activations.attention.att_sums, activations.x, layer, - /*is_attention=*/true); + /*is_attention=*/true, env.ctx); RMSNormBatched(activations.x, layer.pre_ffw_norm_scale, - activations.pre_ffw_rms_out); + activations.pre_ffw_rms_out, env.ctx); if (layer_config.type == LayerAttentionType::kVit) { FFWVit(layer, activations, env); @@ -113,10 +113,10 
@@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens, } PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale, - activations.ffw_out); + activations.ffw_out, env.ctx); ResidualConnection(activations.ffw_out, activations.x, layer, - /*is_attention=*/false); + /*is_attention=*/false, env.ctx); } // Returns the scale value to use for the embedding (basically sqrt model_dim). @@ -158,6 +158,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, const size_t model_dim = model_config.model_dim; const float emb_scaling = EmbeddingScaling(model_dim); + const size_t worker = 0; // Not yet parallelized. HWY_DASSERT(token >= 0); HWY_DASSERT(token < static_cast(model_config.vocab_size)); @@ -173,7 +174,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, const hn::ScalableTag df; DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi), model_dim); - MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim, worker); }); if (model_config.absolute_pe) { @@ -302,6 +303,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, } } + // TODO: parallelize? for (size_t qi = 0; qi < qbatch.Size(); ++qi) { EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi), /*pos_in_prompt=*/0, config, weights, activations.x); @@ -328,7 +330,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos) { - PROFILER_ZONE("Gen.Prefill"); + PROFILER_ZONE("Gen.PrefillQ"); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { non_eos.Set(qi); @@ -400,7 +402,7 @@ static void DecodeStepT(const ModelConfig& config, Transformer(config, runtime_config, weights, activations, qbatch, env); - RMSNormInplaceBatched(weights.final_norm_scale, activations.x); + RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); if (HWY_UNLIKELY(runtime_config.activations_observer)) { runtime_config.activations_observer( @@ -414,9 +416,10 @@ static void DecodeStepT(const ModelConfig& config, /*add=*/nullptr, env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); + const size_t worker = 0; // TODO: parallelize non_eos.Foreach([&](size_t qi) { float* HWY_RESTRICT logits = activations.logits.Row(qi); - MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size); + MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size, worker); const TokenAndProb tp = sample_token(logits, config.vocab_size); timing_info.NotifyGenerated(); @@ -430,10 +433,12 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; + const size_t worker = 0; // TODO: parallelize + // Fast path for top-1 with no accept_token. 
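// (Top-1 needs only the argmax token and its softmax probability, so
// Top1OfSoftmax can skip the general top-k/sampling path.)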
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE("Gen.Sample Top1"); + PROFILER_ZONE2(worker, "Gen.Sample Top1"); return Top1OfSoftmax(logits, vocab_size); }; } @@ -444,7 +449,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) { PROFILER_ZONE("Gen.Sample general"); return FusedSoftmaxAndSampleTopK( logits, runtime_config.top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token); + runtime_config.temperature, runtime_config.accept_token, worker); }; } diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index df950a6..70268c7 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -32,6 +32,7 @@ #include "util/mat.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // HWY_ABORT +#include "hwy/profiler.h" namespace gcpp { @@ -116,6 +117,7 @@ struct RuntimeConfig { // If non-null, `batch_stream_token` is called for each token in the batch, // otherwise `stream_token`. `query_idx` is absolute, not batch-relative. bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { + PROFILER_ZONE("Gen.StreamToken"); if (batch_stream_token) { return batch_stream_token(query_idx, pos, token, prob); } diff --git a/gemma/vit.cc b/gemma/vit.cc index 82838f2..3549f85 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -90,12 +90,12 @@ class VitAttention { ZeroInit(activations_.attention.att_out); for (size_t head = 0; head < heads; ++head) { - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { const size_t token = task; float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working - MulByConst(query_scale, q, qkv_dim); + MulByConst(query_scale, q, qkv_dim, worker); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); }); @@ -109,19 +109,19 @@ class VitAttention { // this produces C, a (num_tokens_, seq_len) matrix of dot products CallMatMul(Q, K, nullptr, env_, C); - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); - Softmax(c, C.Cols()); + Softmax(c, C.Cols(), worker); }); - pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t worker) HWY_ATTR { size_t token = task; float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim, worker); } }); } @@ -138,13 +138,13 @@ class VitAttention { // Compute Q.K, softmax, and weighted V. pool_.Run(0, layer_config_.heads * num_tokens_, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + [&](uint64_t task, size_t worker) HWY_ATTR { const size_t head = task % layer_config_.heads; const size_t token = task / layer_config_.heads; // Compute Q.K scores, which are "logits" stored in head_att. 
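// That is, head_att[i] = (query_scale * q) . k_i for each key i in
// [0, seq_len); the scores are softmax-normalized below and then used to
// weight the V rows: att_out = sum_i head_att[i] * v_i.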
float* HWY_RESTRICT q = activations_.attention.q.Row(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim); + MulByConst(query_scale, q, qkv_dim, worker); float* HWY_RESTRICT head_att = activations_.attention.att.Row(token) + head * seq_len; for (size_t i = 0; i < seq_len; ++i) { @@ -153,7 +153,7 @@ class VitAttention { head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len); + Softmax(head_att, seq_len, worker); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = activations_.attention.att_out.Row(token) + head * qkv_dim; @@ -161,7 +161,7 @@ class VitAttention { for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = activations_.attention.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim, worker); } }); } @@ -259,7 +259,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, VitAttention(num_tokens, layer_idx, activations, layer, env)(); // x = out["+sa"] = x + y - AddFromBatched(activations.attention.att_sums, x); + AddFromBatched(activations.attention.att_sums, x, env.ctx); // y = nn.LayerNorm()(x) // y ~ pre_ffw_rms_out @@ -271,7 +271,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, FFWVit(layer, activations, env); // x = out["+mlp"] = x + y - AddFromBatched(activations.ffw_out, x); + AddFromBatched(activations.ffw_out, x, env.ctx); } // Gets the patches of the image and embeds them with the image embedding @@ -303,7 +303,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, // Add position embeddings. CallUpcastedActivation(&weights.vit_img_pos_embedding, [&](const auto* weights_t) { - AddFromBatched(*weights_t, activations.x); + AddFromBatched(*weights_t, activations.x, env.ctx); }); } @@ -334,7 +334,7 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, // Apply soft embedding norm before input projection. 
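// The ops now take a mandatory worker index so that PROFILER_ZONE2 can
// track per-worker concurrency; 0 is passed here because this path runs
// on a single thread.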
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), - vit_model_dim); + vit_model_dim, /*worker=*/0); }); } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 9d67a10..6712da3 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -875,9 +875,6 @@ class MMPerPackage { inner_tasks_(config.InnerTasks()), out_(config.Out()), line_bytes_(args.env->ctx.allocator.LineBytes()) { - static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA"); - MMZone zone; - zone.MaybeEnter(pkg_idx, zone_id, args_); A_ = DecompressA(A); } @@ -1119,8 +1116,14 @@ class MMPerPackage { const size_t NBF = hn::Lanes(dbf); static_assert(hwy::IsSameEither(), "Can seek"); + static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DecompressA"); + const auto do_range = [&](const IndexRange& range_M, - const IndexRange& range_K) HWY_ATTR { + const IndexRange& range_K, + size_t worker) HWY_ATTR { + MMZone zone; + zone.MaybeEnter(worker, zone_id, args_); + const size_t col0 = range_K.begin(); const size_t cols = range_K.Num(); // Must be a vector multiple, or the last range before row padding, @@ -1141,7 +1144,7 @@ class MMPerPackage { switch (par_a) { case MMParA::kNone: - do_range(all_M, all_K); + do_range(all_M, all_K, /*worker=*/0); break; case MMParA::kK1: case MMParA::kK2: @@ -1154,15 +1157,15 @@ class MMPerPackage { args_.env->parallel.ForNP( all_K, multiple_K, inner_tasks, pkg_idx_, - [&](const IndexRange& range_K, size_t /*worker*/) { - do_range(all_M, range_K); + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); }); break; } case MMParA::kM: args_.env->parallel.ForRangeMC( - all_M, pkg_idx_, [&](size_t row_a, size_t /*worker*/) { - do_range(IndexRange(row_a, row_a + 1), all_K); + all_M, pkg_idx_, [&](size_t row_a, size_t worker) { + do_range(IndexRange(row_a, row_a + 1), all_K, worker); }); break; } @@ -1190,12 +1193,9 @@ class MMPerPackage { // First call: generate candidates. if (HWY_UNLIKELY(!autotune.HasCandidates())) { - std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4}; - if (A.Rows() == 1) { - candidates.push_back(MMParA::kNone); - } else { - candidates.push_back(MMParA::kM); - } + const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM; + std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4, + other}; autotune.SetCandidates(candidates); } @@ -1279,7 +1279,8 @@ struct MMImpl { static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows, const MMArgs& args, const MMConfig& config) { - static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul"); + PROFILER_ZONE("MM.DoMatMul"); + static const uint32_t zone_id = PROFILER_ADD_ZONE("MM.DoMatMul.PerPkg"); // Outermost loop: static NUMA-aware partition of B rows across packages. args.env->parallel.ForPkg( @@ -1353,7 +1354,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, return &per_key; } - PROFILER_ZONE("Matmul.Autotune"); + // From here, CPU time is negligible except DoMatMul. // First call: enumerate all feasible configs. if (HWY_UNLIKELY(!tuner.HasCandidates())) { @@ -1364,7 +1365,6 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(N <= MMStorage::kMaxN); HWY_ASSERT(N % kNR == 0); - // Negligible CPU time. 
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, kNR, per_key.ranges_np, env.print_config)); } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 0c53805..0806846 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -208,7 +208,7 @@ template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( const XT* HWY_RESTRICT x, const WT* HWY_RESTRICT weight, size_t w_ofs, OT* HWY_RESTRICT out, const size_t size, - const size_t HWY_MAYBE_UNUSED worker = 0) { + const size_t HWY_MAYBE_UNUSED worker) { PROFILER_ZONE2(worker, "ops.RMSNorm"); namespace hn = hwy::HWY_NAMESPACE; @@ -240,7 +240,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( template HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const WT* HWY_RESTRICT weight, size_t w_ofs, XT* HWY_RESTRICT inout, - const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + const size_t size, const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.RMSNormInplace"); namespace hn = hwy::HWY_NAMESPACE; @@ -527,7 +527,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( template static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( const XT* HWY_RESTRICT x, float* HWY_RESTRICT out, const size_t size, - const HWY_MAYBE_UNUSED size_t worker = 0) { + const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.AddFrom"); namespace hn = hwy::HWY_NAMESPACE; @@ -570,29 +570,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( // Simple loops unless/until batch sizes are large enough to parallelize. template void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, - MatPtrT& out) { + MatPtrT& out, ThreadingContext& ctx) { HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Cols() == activations.Cols()); HWY_DASSERT(activations.SameShape(out)); CallUpcasted(&weights, [&](const auto* weights_t) { - for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) { - RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0, - out.Row(token_idx), activations.Cols()); - } + const size_t pkg_idx = 0; + SmallParallelFor(activations.Rows(), ctx.pools, pkg_idx, + [&](uint64_t token_idx, size_t worker) { + RMSNorm(activations.Row(token_idx), + weights_t->PackedScale1(), 0, out.Row(token_idx), + activations.Cols(), worker); + }); }); } template -void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout) { +void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, + ThreadingContext& ctx) { HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) { - RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx), - inout.Cols()); - } + const size_t pkg_idx = 0; + SmallParallelFor(inout.Rows(), ctx.pools, pkg_idx, + [&](uint64_t token_idx, size_t worker) { + RMSNormInplace(weights_t->PackedScale1(), 0, + inout.Row(token_idx), inout.Cols(), + worker); + }); }); } @@ -614,18 +621,20 @@ void LayerNormBatched(const MatPtrT& x, const MatPtr& weight, } template -static HWY_INLINE void AddFromBatched(const MatPtrT& x, - MatPtrT& out) { +static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, + ThreadingContext& ctx) { HWY_DASSERT(out.SameShape(x)); - for (size_t token_idx = 0; token_idx < out.Rows(); ++token_idx) { - AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols()); - } + const size_t pkg_idx = 0; + SmallParallelFor( + out.Rows(), ctx.pools, pkg_idx, [&](uint64_t token_idx, size_t worker) { + AddFrom(x.Row(token_idx), 
out.Row(token_idx), x.Cols(), worker); + }); } template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( const float c, XT* HWY_RESTRICT x, const size_t size, - const HWY_MAYBE_UNUSED size_t worker = 0) { + const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.MulByConst"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; @@ -666,7 +675,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst( template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + const size_t size, const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.MulByConstTo"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; @@ -708,7 +717,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, - const size_t size, const HWY_MAYBE_UNUSED size_t worker = 0) { + const size_t size, const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.MulByConstAndAdd"); namespace hn = hwy::HWY_NAMESPACE; const hn::ScalableTag df; @@ -754,8 +763,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( // See below for a specialized version for top-1 sampling. static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - float temperature = 1.0f, - const HWY_MAYBE_UNUSED size_t worker = 0) { + const size_t worker, + float temperature = 1.0f) { PROFILER_ZONE2(worker, "ops.Softmax"); HWY_DASSERT(size != 0); @@ -797,7 +806,7 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, const float sum_exp = Sum(d, x, size); // Double-precision reciprocal does not appear to affect the results. const float mul = 1.0f / sum_exp; - MulByConst(mul, x, size); + MulByConst(mul, x, size, worker); } // Note: https://arxiv.org/pdf/2001.04438 proposes to replace the three max / @@ -886,9 +895,9 @@ static HWY_MAYBE_UNUSED TokenAndProb Top1OfSoftmax(float* HWY_RESTRICT x, return TokenAndProb{.token = argmax.token, .prob = prob}; } -static HWY_NOINLINE void LogitsSoftCap( - const float cap, float* HWY_RESTRICT x, const size_t size, - const HWY_MAYBE_UNUSED size_t worker = 0) { +static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, + const size_t size, + const HWY_MAYBE_UNUSED size_t worker) { PROFILER_ZONE2(worker, "ops.LogitsSoftCap"); namespace hn = hwy::HWY_NAMESPACE; @@ -906,7 +915,7 @@ static HWY_NOINLINE void LogitsSoftCap( // Calls LogitsSoftCap if cap != 0.0f. static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( const float cap, float* HWY_RESTRICT x, const size_t size, - const size_t worker = 0) { + const size_t worker) { if (cap != 0.0f) { LogitsSoftCap(cap, x, size, worker); } @@ -991,7 +1000,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( template HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( const float* HWY_RESTRICT logits, size_t k, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token) { + std::mt19937& gen, float temperature, TAcceptToken& accept_token, + size_t worker) { // Softmax and sample top-K is equivalent to taking the top-K logits and // sampling from the softmax of the top-K logits. The latter is faster as it // avoids computing the softmax of all logits. 
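// Concretely, restricting the full softmax to the retained top-k set S and
// renormalizing gives, for i in S and temperature T,
//   p_i = exp(x_i / T) / sum_{j in S} exp(x_j / T),
// i.e. exactly the softmax of the k retained logits, because the
// normalizer over the whole vocabulary cancels out.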
@@ -1005,7 +1015,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( } size_t mask = token_logits.size(); - Softmax(topk_logits.data(), mask, temperature); + Softmax(topk_logits.data(), mask, worker, temperature); auto distribution = std::discrete_distribution( std::begin(topk_logits), std::begin(topk_logits) + mask); int topk_sampled_index = distribution(gen); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index da1e23c..2a51839 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -166,7 +166,7 @@ struct TestAddFrom { } SimpleAddFrom(o, e, count); - AddFrom(o, x, count); + AddFrom(o, x, count, /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -199,7 +199,7 @@ struct TestMulByConstAndAdd { T constant = Random(rng); SimpleMulByConstAndAdd(constant, o, e, count); - MulByConstAndAdd(constant, o, x, count); + MulByConstAndAdd(constant, o, x, count, /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -229,7 +229,7 @@ struct TestMulByConst { T constant = Random(rng); SimpleMulByConst(constant, e, count); - MulByConst(constant, x, count); + MulByConst(constant, x, count, /*worker=*/0); hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, __LINE__); @@ -259,7 +259,7 @@ struct TestSoftmax { } SimpleSoftmax(e, count); - Softmax(x, count); + Softmax(x, count, /*worker=*/0); T sum = 0.0f; for (size_t i = 0; i < count; ++i) { @@ -454,7 +454,7 @@ void TestRMSNorm(hwy::RandomState& rng) { } ScalarRMSNorm(vec, weight, expected, kSize); - RMSNorm(vec, weight, 0, actual, kSize); + RMSNorm(vec, weight, 0, actual, kSize, /*worker=*/0); for (size_t i = 0; i < kSize; i++) { const float e = hwy::ConvertScalarTo(expected[i]); @@ -584,7 +584,7 @@ void TestSampleTopK() { std::vector logits(kSize); // Create a vector going from -100 to -100+51=49 and take Softmax. std::iota(logits.begin(), logits.end(), -100.0f); - Softmax(logits.data(), kSize); + Softmax(logits.data(), kSize, /*worker=*/0); std::mt19937 gen; gen.seed(0x12345678); float temperature = 1.0f; @@ -600,7 +600,7 @@ void TestSampleTopK() { EXPECT_EQ(sample, 50); // Last even index. // Reset the logits to a positive, increasing sequence and take Softmax. std::iota(logits.begin(), logits.end(), 1.0f); - Softmax(logits.data(), kSize); + Softmax(logits.data(), kSize, /*worker=*/0); // Sample from the top 3, expect one of the top 3 even indices. for (int i = 0; i < 100; ++i) { sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature, diff --git a/util/threading.h b/util/threading.h index efb536f..8d2c013 100644 --- a/util/threading.h +++ b/util/threading.h @@ -355,6 +355,17 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx, }); } +// As above, but for lightweight tasks. Uses only one pool. +template +void SmallParallelFor(size_t num_tasks, NestedPools& pools, size_t pkg_idx, + const Func& func) { + const size_t pkg_base = pkg_idx * pools.MaxWorkersPerPackage(); + + pools.Pool(pkg_idx).Run(0, num_tasks, [&](uint64_t task, size_t thread) { + func(task, pkg_base + thread); + }); +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
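A minimal usage sketch of the new single-pool helper, mirroring the batched RMSNorm/AddFrom wrappers above (the function name and the 0.5f scale are illustrative only; SmallParallelFor, MulByConst, MatPtrT and ThreadingContext are the ones defined in this patch):

// One lightweight task per row; each task receives worker = pkg_base + thread,
// so profiler zones remain distinct across packages.
void ScaleRowsExample(MatPtrT<float>& x, ThreadingContext& ctx) {
  const size_t pkg_idx = 0;  // single package, as in the wrappers above
  SmallParallelFor(x.Rows(), ctx.pools, pkg_idx,
                   [&](uint64_t row, size_t worker) {
                     MulByConst(0.5f, x.Row(row), x.Cols(), worker);
                   });
}

Heavier per-row work (e.g. per-head attention) keeps using ParallelFor, which takes the same arguments, so switching between the two variants is a local change at the call site.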