1.15x speedup: parallel sampling, enabled by new RNG

Also pass pos to SampleFunc, for seeding the RNG.

PiperOrigin-RevId: 803453518
This commit is contained in:
Jan Wassenberg 2025-09-05 07:23:33 -07:00 committed by Copybara-Service
parent ad7d7a2713
commit cbe24eac51
4 changed files with 68 additions and 41 deletions

View File

@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens, float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache, const std::vector<int>& prompt, KVCache& kv_cache,
MatMulEnv& env, int verbosity) { MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; }; const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
return true;
};
const int vocab_size = gemma.Config().vocab_size; const int vocab_size = gemma.Config().vocab_size;
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s) float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
size_t pos = 1;
const SampleFunc sample_token = [&](size_t qi, const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
Logits logits) -> TokenAndProb { size_t /*worker*/) -> TokenAndProb {
// input is logits, not yet probabilities // input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler); HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
// We are called for each token, but pos starts at 1. Clamping // We are called for each token, but pos starts at 1. Clamping
@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1, printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1)); cross_entropy / std::log(2.0) / (pos + 1));
} }
++pos;
return TokenAndProb{.token = token, .prob = prob}; return TokenAndProb{.token = token, .prob = prob};
}; };
@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
.max_generated_tokens = max_generated_tokens - 1, .max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f, .temperature = 0.0f,
.verbosity = verbosity, .verbosity = verbosity,
.stream_token = stream_token, .batch_stream_token = stream_token,
.sample_func = sample_token, .sample_func = sample_token,
}; };
TimingInfo timing_info; TimingInfo timing_info;

View File

@ -132,6 +132,7 @@ struct Activations {
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)), x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
logits( logits(
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)), MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
config.model_dim, ctx.allocator)), config.model_dim, ctx.allocator)),
@ -164,6 +165,7 @@ struct Activations {
x.OverrideRows(batch_size); x.OverrideRows(batch_size);
x_bf.OverrideRows(batch_size); x_bf.OverrideRows(batch_size);
logits.OverrideRows(batch_size); logits.OverrideRows(batch_size);
sampled.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size); pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size); C1.OverrideRows(batch_size);
@ -178,6 +180,7 @@ struct Activations {
MatStorageT<float> x; // input MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; MatStorageT<float> logits;
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
// Gated FFW // Gated FFW
MatStorageT<BF16> pre_ffw_rms_out; MatStorageT<BF16> pre_ffw_rms_out;

View File

@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the // `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
// query is at the end of its sequence. // query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token,
const ModelConfig& config, const float prob, const ModelConfig& config,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
QBatch& qbatch, bool pos_plus_1, bool update_pos, QBatch& qbatch, bool update_pos,
hwy::BitSet4096<>& non_eos) { hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
if (HWY_UNLIKELY( if (HWY_UNLIKELY(
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) { !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
// User decided to stop: set token to primary EOS to trigger IsEOS below. // User decided to stop: set token to primary EOS to trigger IsEOS below.
@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
// Must be called after Transformer: either after prefill, or during decode. // Must be called after Transformer: either after prefill, or during decode.
// Computes logits, samples and streams the token. // Computes logits, samples and streams the token.
static void SampleAndStream( static void SampleAndStream(const ModelConfig& config,
const ModelConfig& config, const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, const SampleFunc& sample_token, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env, const SampleFunc& sample_token,
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows()); HWY_DASSERT(qbatch.Size() == activations.x.Rows());
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf, RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
@ -429,16 +430,33 @@ static void SampleAndStream(
timing_info.NotifyGenerated(non_eos.Count()); timing_info.NotifyGenerated(non_eos.Count());
// TODO: parallelize ParallelFor(
non_eos.Foreach([&](size_t qi) { ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi)); /*cluster_idx=*/0, [&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return;
// We streamed all prefill tokens, but pos is still one behind because we // We streamed all prefill tokens, but pos is still one behind
// started generation at pos = prompt.size() - 1. We want the pos argument // because we started generation at pos = prompt.size() - 1.
// to match the number of calls to `StreamToken`, as expected by the caller. // We want the pos argument to match the number of calls to
const bool pos_plus_1 = true; // `StreamToken`, as expected by the caller.
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch, const size_t pos = qbatch.Pos(qi) + 1;
pos_plus_1, update_pos, non_eos);
const TokenAndProb tp =
sample_token(qi, pos, activations.logits.RowSpan(qi), worker);
// `sampled` is padded, which prevents false sharing.
activations.sampled.Row(qi)[0] = static_cast<uint32_t>(pos);
activations.sampled.Row(qi)[1] = static_cast<uint32_t>(tp.token);
activations.sampled.Row(qi)[2] = hwy::BitCastScalar<uint32_t>(tp.prob);
});
// Sequentially, because `StreamToken` is not yet thread-safe.
non_eos.Foreach([&](size_t qi) {
const size_t pos = activations.sampled.Row(qi)[0];
const int token = static_cast<int>(activations.sampled.Row(qi)[1]);
const float prob =
hwy::BitCastScalar<float>(activations.sampled.Row(qi)[2]);
StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch,
/*update_pos=*/true, non_eos);
}); });
} }
@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// If user provided a sample_func, use it. // If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func; if (runtime_config.sample_func) return runtime_config.sample_func;
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1"); static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
const size_t worker = 0; // TODO: parallelize static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
// Fast path for top-1 with no accept_token. // Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) { if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb { return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
PROFILER_ZONE3(ctx.profiler, worker, zone); HWY_ATTR -> TokenAndProb {
return Top1OfSoftmax(logits); PROFILER_ZONE3(ctx.profiler, worker, zone_top1);
}; return Top1OfSoftmax(logits);
};
} }
// General case: Softmax with top-k sampling. // General case: Softmax with top-k sampling.
return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb { return [&](size_t qi, size_t pos, Logits logits,
PROFILER_ZONE("Gen.Sample general"); size_t worker) HWY_ATTR -> TokenAndProb {
RngStream gen(engine, qi); PROFILER_ZONE3(ctx.profiler, worker, zone_topK);
// We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream);
return FusedSoftmaxAndSampleTopK( return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, gen, runtime_config.temperature, logits, runtime_config.top_k, gen, runtime_config.temperature,
runtime_config.accept_token, ctx.profiler, worker); runtime_config.accept_token, ctx.profiler, worker);
@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config,
// Stream the last prompt token from each query, fill activations.gen_tokens. // Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) { for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const bool pos_plus_1 = false; // during prefill, pos is still correct.
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
// In autoregressive mode, we have not prefilled the last token, so do // In autoregressive mode, we have not prefilled the last token, so do
// not advance. // not advance.
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
runtime_config, qbatch, pos_plus_1, update_pos, non_eos); config, runtime_config, qbatch, update_pos, non_eos);
} }
size_t max_gen_steps = runtime_config.max_generated_tokens; size_t max_gen_steps = runtime_config.max_generated_tokens;
@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config,
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env); Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations, SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, /*update_pos=*/true, env, non_eos, timing_info); qbatch, env, non_eos, timing_info);
} }
timing_info.NotifyGenerateDone(); timing_info.NotifyGenerateDone();
} }

View File

@ -89,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for // If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate. // tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>; using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the query_idx and logits for the // If not empty, SampleFunc is called concurrently from worker thread(s) with
// next token, which it may modify/overwrite. It returns the next generated // query_idx, pos, logits for the next token (which it may modify/overwrite),
// token together with its probability. // and worker. It returns the next generated token and its probability.
using SampleFunc = std::function<TokenAndProb(size_t, Logits)>; using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with: // If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise. // - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence // - position in the tokens sequence
@ -115,6 +115,7 @@ using ActivationsObserverFunc =
struct RuntimeConfig { struct RuntimeConfig {
// If non-null, `batch_stream_token` is called for each token in the batch, // If non-null, `batch_stream_token` is called for each token in the batch,
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative. // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
// This is called sequentially from the main thread.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
PROFILER_ZONE("Gen.StreamToken"); PROFILER_ZONE("Gen.StreamToken");
if (batch_stream_token) { if (batch_stream_token) {