From cbe24eac51c089f6a4f126f62683489c0fe794f0 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 5 Sep 2025 07:23:33 -0700 Subject: [PATCH] 1.15x speedup: parallel sampling, enabled by new RNG Also pass pos to SampleFunc, for seeding the RNG. PiperOrigin-RevId: 803453518 --- evals/cross_entropy.cc | 12 +++--- gemma/activations.h | 3 ++ gemma/gemma.cc | 85 +++++++++++++++++++++++++++--------------- gemma/gemma_args.h | 9 +++-- 4 files changed, 68 insertions(+), 41 deletions(-) diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index b7abb10..49acb50 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax); float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, const std::vector& prompt, KVCache& kv_cache, MatMulEnv& env, int verbosity) { - const StreamFunc stream_token = [](int, float) { return true; }; + const BatchStreamFunc stream_token = [](size_t, size_t, int, float) { + return true; + }; const int vocab_size = gemma.Config().vocab_size; float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) - size_t pos = 1; - const SampleFunc sample_token = [&](size_t qi, - Logits logits) -> TokenAndProb { + const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits, + size_t /*worker*/) -> TokenAndProb { // input is logits, not yet probabilities HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); // We are called for each token, but pos starts at 1. Clamping @@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, cross_entropy / std::log(2.0) / (pos + 1)); } - ++pos; return TokenAndProb{.token = token, .prob = prob}; }; @@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, .max_generated_tokens = max_generated_tokens - 1, .temperature = 0.0f, .verbosity = verbosity, - .stream_token = stream_token, + .batch_stream_token = stream_token, .sample_func = sample_token, }; TimingInfo timing_info; diff --git a/gemma/activations.h b/gemma/activations.h index 67e1eba..21e5e58 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -132,6 +132,7 @@ struct Activations { x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)), logits( MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)), + sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)), pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, config.model_dim, ctx.allocator)), @@ -164,6 +165,7 @@ struct Activations { x.OverrideRows(batch_size); x_bf.OverrideRows(batch_size); logits.OverrideRows(batch_size); + sampled.OverrideRows(batch_size); pre_ffw_rms_out.OverrideRows(batch_size); C1.OverrideRows(batch_size); @@ -178,6 +180,7 @@ struct Activations { MatStorageT x; // input MatStorageT x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT logits; + MatStorageT sampled; // batch_size x 3 (padded) // Gated FFW MatStorageT pre_ffw_rms_out; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 62288ff..785bd87 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent // `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the // query is at the end of its sequence. -static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, - const ModelConfig& config, +static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token, + const float prob, const ModelConfig& config, const RuntimeConfig& runtime_config, - QBatch& qbatch, bool pos_plus_1, bool update_pos, + QBatch& qbatch, bool update_pos, hwy::BitSet4096<>& non_eos) { HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. - const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0); if (HWY_UNLIKELY( !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) { // User decided to stop: set token to primary EOS to trigger IsEOS below. @@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, // Must be called after Transformer: either after prefill, or during decode. // Computes logits, samples and streams the token. -static void SampleAndStream( - const ModelConfig& config, const RuntimeConfig& runtime_config, - const WeightsPtrs& weights, const SampleFunc& sample_token, - Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env, - hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { +static void SampleAndStream(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + const SampleFunc& sample_token, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, hwy::BitSet4096<>& non_eos, + TimingInfo& timing_info) { HWY_DASSERT(qbatch.Size() == activations.x.Rows()); RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf, @@ -429,16 +430,33 @@ static void SampleAndStream( timing_info.NotifyGenerated(non_eos.Count()); - // TODO: parallelize - non_eos.Foreach([&](size_t qi) { - const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi)); + ParallelFor( + ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, + /*cluster_idx=*/0, [&](size_t qi, size_t worker) { + if (!non_eos.Get(qi)) return; - // We streamed all prefill tokens, but pos is still one behind because we - // started generation at pos = prompt.size() - 1. We want the pos argument - // to match the number of calls to `StreamToken`, as expected by the caller. - const bool pos_plus_1 = true; - StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch, - pos_plus_1, update_pos, non_eos); + // We streamed all prefill tokens, but pos is still one behind + // because we started generation at pos = prompt.size() - 1. + // We want the pos argument to match the number of calls to + // `StreamToken`, as expected by the caller. + const size_t pos = qbatch.Pos(qi) + 1; + + const TokenAndProb tp = + sample_token(qi, pos, activations.logits.RowSpan(qi), worker); + // `sampled` is padded, which prevents false sharing. + activations.sampled.Row(qi)[0] = static_cast(pos); + activations.sampled.Row(qi)[1] = static_cast(tp.token); + activations.sampled.Row(qi)[2] = hwy::BitCastScalar(tp.prob); + }); + + // Sequentially, because `StreamToken` is not yet thread-safe. + non_eos.Foreach([&](size_t qi) { + const size_t pos = activations.sampled.Row(qi)[0]; + const int token = static_cast(activations.sampled.Row(qi)[1]); + const float prob = + hwy::BitCastScalar(activations.sampled.Row(qi)[2]); + StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch, + /*update_pos=*/true, non_eos); }); } @@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; - static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1"); - const size_t worker = 0; // TODO: parallelize + static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1"); + static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general"); // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { - return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE3(ctx.profiler, worker, zone); - return Top1OfSoftmax(logits); - }; + return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) + HWY_ATTR -> TokenAndProb { + PROFILER_ZONE3(ctx.profiler, worker, zone_top1); + return Top1OfSoftmax(logits); + }; } // General case: Softmax with top-k sampling. - return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb { - PROFILER_ZONE("Gen.Sample general"); - RngStream gen(engine, qi); + return [&](size_t qi, size_t pos, Logits logits, + size_t worker) HWY_ATTR -> TokenAndProb { + PROFILER_ZONE3(ctx.profiler, worker, zone_topK); + // We want a different sequence for each batch element and position. + const uint64_t stream = (static_cast(qi) << 32) | pos; + RngStream gen(engine, stream); return FusedSoftmaxAndSampleTopK( logits, runtime_config.top_k, gen, runtime_config.temperature, runtime_config.accept_token, ctx.profiler, worker); @@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config, // Stream the last prompt token from each query, fill activations.gen_tokens. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); - const bool pos_plus_1 = false; // during prefill, pos is still correct. + + const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. // In autoregressive mode, we have not prefilled the last token, so do // not advance. const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); - StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, - runtime_config, qbatch, pos_plus_1, update_pos, non_eos); + StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, + config, runtime_config, qbatch, update_pos, non_eos); } size_t max_gen_steps = runtime_config.max_generated_tokens; @@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config, for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { Transformer(config, runtime_config, weights, activations, qbatch, env); SampleAndStream(config, runtime_config, weights, sample_token, activations, - qbatch, /*update_pos=*/true, env, non_eos, timing_info); + qbatch, env, non_eos, timing_info); } timing_info.NotifyGenerateDone(); } diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 59e3a6c..b2d19ff 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -89,10 +89,10 @@ using BatchStreamFunc = std::function; // If not empty, AcceptFunc is called with token. It should return false for // tokens you don't want to generate and true for tokens you want to generate. using AcceptFunc = std::function; -// If not empty, SampleFunc is called with the query_idx and logits for the -// next token, which it may modify/overwrite. It returns the next generated -// token together with its probability. -using SampleFunc = std::function; +// If not empty, SampleFunc is called concurrently from worker thread(s) with +// query_idx, pos, logits for the next token (which it may modify/overwrite), +// and worker. It returns the next generated token and its probability. +using SampleFunc = std::function; // If not empty, LayersOutputFunc is called for layer outputs, specified with: // - index of query within containing batch (if any); zero otherwise. // - position in the tokens sequence @@ -115,6 +115,7 @@ using ActivationsObserverFunc = struct RuntimeConfig { // If non-null, `batch_stream_token` is called for each token in the batch, // otherwise `stream_token`. `query_idx` is absolute, not batch-relative. + // This is called sequentially from the main thread. bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { PROFILER_ZONE("Gen.StreamToken"); if (batch_stream_token) {