1.15x speedup: parallel sampling, enabled by new RNG

Also pass pos to SampleFunc so it can seed the RNG.

PiperOrigin-RevId: 803453518
Jan Wassenberg 2025-09-05 07:23:33 -07:00 committed by Copybara-Service
parent ad7d7a2713
commit cbe24eac51
4 changed files with 68 additions and 41 deletions

View File

@@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax);
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
MatMulEnv& env, int verbosity) {
const StreamFunc stream_token = [](int, float) { return true; };
const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
return true;
};
const int vocab_size = gemma.Config().vocab_size;
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
size_t pos = 1;
const SampleFunc sample_token = [&](size_t qi,
Logits logits) -> TokenAndProb {
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
size_t /*worker*/) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
// We are called for each token, but pos starts at 1. Clamping
@@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
cross_entropy / std::log(2.0) / (pos + 1));
}
++pos;
return TokenAndProb{.token = token, .prob = prob};
};
@@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
.max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.stream_token = stream_token,
.batch_stream_token = stream_token,
.sample_func = sample_token,
};
TimingInfo timing_info;

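A side note on the cross-entropy arithmetic this hunk leaves unchanged: the accumulator starts at log(vocab_size) nats, the cost of a uniform guess for the first token, and the printf divides by log(2) to report bits per token. A standalone sketch, with vocab_size as an assumed placeholder rather than Gemma's actual value:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const double vocab_size = 256000.0;  // placeholder; real code uses gemma.Config().vocab_size
  // Uniform prior over the vocabulary: -log(1/V) == log(V) nats for the first token.
  double cross_entropy = std::log(vocab_size);
  const size_t pos = 0;  // first token
  // Divide by log(2) to convert nats to bits, then average over the pos + 1 tokens seen.
  std::printf("cross-entropy per token: %f bits\n",
              cross_entropy / std::log(2.0) / (pos + 1));  // ~17.97 bits
  return 0;
}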
View File

@@ -132,6 +132,7 @@ struct Activations {
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
logits(
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
config.model_dim, ctx.allocator)),
@@ -164,6 +165,7 @@ struct Activations {
x.OverrideRows(batch_size);
x_bf.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
sampled.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
@@ -178,6 +180,7 @@ struct Activations {
MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits;
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;

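The new `sampled` tensor gives each query a padded row of three u32 slots holding the position, the token, and the bit pattern of the probability; the padding keeps concurrent writers on separate cache lines. A sketch of the pack/unpack round trip mirroring the writes in SampleAndStream below; the SampledRow helper is hypothetical, not part of this change:

#include <cstddef>
#include <cstdint>

#include "hwy/base.h"  // hwy::BitCastScalar

// Hypothetical helper; gemma.cpp writes these slots inline.
struct SampledRow {
  static void Store(uint32_t* row, size_t pos, int token, float prob) {
    row[0] = static_cast<uint32_t>(pos);
    row[1] = static_cast<uint32_t>(token);
    row[2] = hwy::BitCastScalar<uint32_t>(prob);  // bit-preserving cast
  }
  static void Load(const uint32_t* row, size_t& pos, int& token, float& prob) {
    pos = row[0];
    token = static_cast<int>(row[1]);
    prob = hwy::BitCastScalar<float>(row[2]);  // float round-trips exactly
  }
};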
View File

@@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
// query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
const ModelConfig& config,
static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token,
const float prob, const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch, bool pos_plus_1, bool update_pos,
QBatch& qbatch, bool update_pos,
hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
if (HWY_UNLIKELY(
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
// User decided to stop: set token to primary EOS to trigger IsEOS below.
@@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
// Must be called after Transformer: either after prefill, or during decode.
// Computes logits, samples and streams the token.
static void SampleAndStream(
const ModelConfig& config, const RuntimeConfig& runtime_config,
const WeightsPtrs& weights, const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
static void SampleAndStream(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const WeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
@@ -429,16 +430,33 @@ static void SampleAndStream(
timing_info.NotifyGenerated(non_eos.Count());
// TODO: parallelize
non_eos.Foreach([&](size_t qi) {
const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi));
ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return;
// We streamed all prefill tokens, but pos is still one behind because we
// started generation at pos = prompt.size() - 1. We want the pos argument
// to match the number of calls to `StreamToken`, as expected by the caller.
const bool pos_plus_1 = true;
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
pos_plus_1, update_pos, non_eos);
// We streamed all prefill tokens, but pos is still one behind
// because we started generation at pos = prompt.size() - 1.
// We want the pos argument to match the number of calls to
// `StreamToken`, as expected by the caller.
const size_t pos = qbatch.Pos(qi) + 1;
const TokenAndProb tp =
sample_token(qi, pos, activations.logits.RowSpan(qi), worker);
// `sampled` is padded, which prevents false sharing.
activations.sampled.Row(qi)[0] = static_cast<uint32_t>(pos);
activations.sampled.Row(qi)[1] = static_cast<uint32_t>(tp.token);
activations.sampled.Row(qi)[2] = hwy::BitCastScalar<uint32_t>(tp.prob);
});
// Sequentially, because `StreamToken` is not yet thread-safe.
non_eos.Foreach([&](size_t qi) {
const size_t pos = activations.sampled.Row(qi)[0];
const int token = static_cast<int>(activations.sampled.Row(qi)[1]);
const float prob =
hwy::BitCastScalar<float>(activations.sampled.Row(qi)[2]);
StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch,
/*update_pos=*/true, non_eos);
});
}
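This hunk restructures sampling into two phases: a ParallelFor samples every live query into its padded `sampled` row, then a sequential Foreach streams the results because `StreamToken` is not yet thread-safe. A minimal sketch of the same pattern using only the standard library; the names are illustrative, not gemma.cpp's API:

#include <cstddef>
#include <thread>
#include <vector>

void SampleThenStream(size_t num_queries) {
  std::vector<int> sampled(num_queries);  // per-query scratch, written in parallel
  std::vector<std::thread> workers;
  // Phase 1: per-query sampling is independent, so it may run concurrently.
  for (size_t qi = 0; qi < num_queries; ++qi) {
    workers.emplace_back([qi, &sampled] { sampled[qi] = static_cast<int>(qi); });
  }
  for (std::thread& t : workers) t.join();
  // Phase 2: consume in deterministic query order on one thread, as a
  // non-thread-safe streaming callback requires.
  for (size_t qi = 0; qi < num_queries; ++qi) {
    (void)sampled[qi];  // the StreamToken equivalent would run here
  }
}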
@@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
// If user provided a sample_func, use it.
if (runtime_config.sample_func) return runtime_config.sample_func;
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
const size_t worker = 0; // TODO: parallelize
static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
// Fast path for top-1 with no accept_token.
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone);
return Top1OfSoftmax(logits);
};
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone_top1);
return Top1OfSoftmax(logits);
};
}
// General case: Softmax with top-k sampling.
return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general");
RngStream gen(engine, qi);
return [&](size_t qi, size_t pos, Logits logits,
size_t worker) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE3(ctx.profiler, worker, zone_topK);
// We want a different sequence for each batch element and position.
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
RngStream gen(engine, stream);
return FusedSoftmaxAndSampleTopK(
logits, runtime_config.top_k, gen, runtime_config.temperature,
runtime_config.accept_token, ctx.profiler, worker);
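The stream id above packs qi into the upper 32 bits and pos into the lower 32, so each (batch element, position) pair draws from its own reproducible RNG sequence regardless of which worker thread executes it. An illustrative check; the mask is added here for self-containment, while the diff simply assumes pos fits in 32 bits:

#include <cstdint>

constexpr uint64_t StreamId(uint64_t qi, uint64_t pos) {
  return (qi << 32) | (pos & 0xFFFFFFFFu);
}
static_assert(StreamId(1, 0) != StreamId(0, 1), "qi and pos occupy disjoint bits");
static_assert(StreamId(2, 7) == ((uint64_t{2} << 32) | 7u), "packs as documented");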
@@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config,
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const bool pos_plus_1 = false; // during prefill, pos is still correct.
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
// In autoregressive mode, we have not prefilled the last token, so do
// not advance.
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
config, runtime_config, qbatch, update_pos, non_eos);
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config,
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, /*update_pos=*/true, env, non_eos, timing_info);
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}

View File

@@ -89,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
using AcceptFunc = std::function<bool(int, float)>;
// If not empty, SampleFunc is called with the query_idx and logits for the
// next token, which it may modify/overwrite. It returns the next generated
// token together with its probability.
using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
// If not empty, SampleFunc is called concurrently from worker thread(s) with
// query_idx, pos, logits for the next token (which it may modify/overwrite),
// and worker. It returns the next generated token and its probability.
using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
// - index of query within containing batch (if any); zero otherwise.
// - position in the tokens sequence
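For callers providing their own sampler, a hedged example matching the new four-argument SampleFunc signature: a deterministic argmax that ignores pos and worker. This assumes Logits is a span-like range of floats with size() and operator[]; a stochastic sampler would instead derive its RNG seed from (query_idx, pos) so output does not depend on thread scheduling:

const SampleFunc greedy = [](size_t /*query_idx*/, size_t /*pos*/, Logits logits,
                             size_t /*worker*/) -> TokenAndProb {
  size_t best = 0;
  for (size_t i = 1; i < logits.size(); ++i) {
    if (logits[i] > logits[best]) best = i;
  }
  // Note: prob is left as a raw logit; a real sampler would apply Softmax first.
  return TokenAndProb{.token = static_cast<int>(best), .prob = logits[best]};
};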
@@ -115,6 +115,7 @@ using ActivationsObserverFunc =
struct RuntimeConfig {
// If non-null, `batch_stream_token` is called for each token in the batch,
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
// This is called sequentially from the main thread.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
PROFILER_ZONE("Gen.StreamToken");
if (batch_stream_token) {