mirror of https://github.com/google/gemma.cpp.git
1.15x speedup: parallel sampling, enabled by new RNG
Also pass pos to SampleFunc, for seeding the RNG. PiperOrigin-RevId: 803453518
This commit is contained in:
parent
ad7d7a2713
commit
cbe24eac51
|
|
@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax);
|
|||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
MatMulEnv& env, int verbosity) {
|
||||
const StreamFunc stream_token = [](int, float) { return true; };
|
||||
const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
|
||||
return true;
|
||||
};
|
||||
|
||||
const int vocab_size = gemma.Config().vocab_size;
|
||||
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
||||
size_t pos = 1;
|
||||
|
||||
const SampleFunc sample_token = [&](size_t qi,
|
||||
Logits logits) -> TokenAndProb {
|
||||
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
|
||||
size_t /*worker*/) -> TokenAndProb {
|
||||
// input is logits, not yet probabilities
|
||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
|
||||
// We are called for each token, but pos starts at 1. Clamping
|
||||
|
|
@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||
cross_entropy / std::log(2.0) / (pos + 1));
|
||||
}
|
||||
++pos;
|
||||
return TokenAndProb{.token = token, .prob = prob};
|
||||
};
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
|||
.max_generated_tokens = max_generated_tokens - 1,
|
||||
.temperature = 0.0f,
|
||||
.verbosity = verbosity,
|
||||
.stream_token = stream_token,
|
||||
.batch_stream_token = stream_token,
|
||||
.sample_func = sample_token,
|
||||
};
|
||||
TimingInfo timing_info;
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ struct Activations {
|
|||
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
|
||||
logits(
|
||||
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
|
||||
sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
|
||||
|
||||
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
|
||||
config.model_dim, ctx.allocator)),
|
||||
|
|
@ -164,6 +165,7 @@ struct Activations {
|
|||
x.OverrideRows(batch_size);
|
||||
x_bf.OverrideRows(batch_size);
|
||||
logits.OverrideRows(batch_size);
|
||||
sampled.OverrideRows(batch_size);
|
||||
|
||||
pre_ffw_rms_out.OverrideRows(batch_size);
|
||||
C1.OverrideRows(batch_size);
|
||||
|
|
@ -178,6 +180,7 @@ struct Activations {
|
|||
MatStorageT<float> x; // input
|
||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||
MatStorageT<float> logits;
|
||||
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
||||
|
||||
// Gated FFW
|
||||
MatStorageT<BF16> pre_ffw_rms_out;
|
||||
|
|
|
|||
|
|
@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
|||
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
||||
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
|
||||
// query is at the end of its sequence.
|
||||
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
||||
const ModelConfig& config,
|
||||
static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token,
|
||||
const float prob, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
QBatch& qbatch, bool pos_plus_1, bool update_pos,
|
||||
QBatch& qbatch, bool update_pos,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
||||
|
||||
const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
|
||||
if (HWY_UNLIKELY(
|
||||
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
|
||||
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
||||
|
|
@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
|||
|
||||
// Must be called after Transformer: either after prefill, or during decode.
|
||||
// Computes logits, samples and streams the token.
|
||||
static void SampleAndStream(
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights, const SampleFunc& sample_token,
|
||||
Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
||||
static void SampleAndStream(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const WeightsPtrs& weights,
|
||||
const SampleFunc& sample_token,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||
|
||||
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
|
||||
|
|
@ -429,16 +430,33 @@ static void SampleAndStream(
|
|||
|
||||
timing_info.NotifyGenerated(non_eos.Count());
|
||||
|
||||
// TODO: parallelize
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi));
|
||||
ParallelFor(
|
||||
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
|
||||
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
|
||||
if (!non_eos.Get(qi)) return;
|
||||
|
||||
// We streamed all prefill tokens, but pos is still one behind because we
|
||||
// started generation at pos = prompt.size() - 1. We want the pos argument
|
||||
// to match the number of calls to `StreamToken`, as expected by the caller.
|
||||
const bool pos_plus_1 = true;
|
||||
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
|
||||
pos_plus_1, update_pos, non_eos);
|
||||
// We streamed all prefill tokens, but pos is still one behind
|
||||
// because we started generation at pos = prompt.size() - 1.
|
||||
// We want the pos argument to match the number of calls to
|
||||
// `StreamToken`, as expected by the caller.
|
||||
const size_t pos = qbatch.Pos(qi) + 1;
|
||||
|
||||
const TokenAndProb tp =
|
||||
sample_token(qi, pos, activations.logits.RowSpan(qi), worker);
|
||||
// `sampled` is padded, which prevents false sharing.
|
||||
activations.sampled.Row(qi)[0] = static_cast<uint32_t>(pos);
|
||||
activations.sampled.Row(qi)[1] = static_cast<uint32_t>(tp.token);
|
||||
activations.sampled.Row(qi)[2] = hwy::BitCastScalar<uint32_t>(tp.prob);
|
||||
});
|
||||
|
||||
// Sequentially, because `StreamToken` is not yet thread-safe.
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
const size_t pos = activations.sampled.Row(qi)[0];
|
||||
const int token = static_cast<int>(activations.sampled.Row(qi)[1]);
|
||||
const float prob =
|
||||
hwy::BitCastScalar<float>(activations.sampled.Row(qi)[2]);
|
||||
StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch,
|
||||
/*update_pos=*/true, non_eos);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
|||
// If user provided a sample_func, use it.
|
||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||
|
||||
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
|
||||
const size_t worker = 0; // TODO: parallelize
|
||||
static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
|
||||
static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
|
||||
|
||||
// Fast path for top-1 with no accept_token.
|
||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||
return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
||||
return Top1OfSoftmax(logits);
|
||||
};
|
||||
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
|
||||
HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone_top1);
|
||||
return Top1OfSoftmax(logits);
|
||||
};
|
||||
}
|
||||
|
||||
// General case: Softmax with top-k sampling.
|
||||
return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE("Gen.Sample general");
|
||||
RngStream gen(engine, qi);
|
||||
return [&](size_t qi, size_t pos, Logits logits,
|
||||
size_t worker) HWY_ATTR -> TokenAndProb {
|
||||
PROFILER_ZONE3(ctx.profiler, worker, zone_topK);
|
||||
// We want a different sequence for each batch element and position.
|
||||
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
|
||||
RngStream gen(engine, stream);
|
||||
return FusedSoftmaxAndSampleTopK(
|
||||
logits, runtime_config.top_k, gen, runtime_config.temperature,
|
||||
runtime_config.accept_token, ctx.profiler, worker);
|
||||
|
|
@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config,
|
|||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||
const bool pos_plus_1 = false; // during prefill, pos is still correct.
|
||||
|
||||
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
||||
// In autoregressive mode, we have not prefilled the last token, so do
|
||||
// not advance.
|
||||
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
|
||||
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
|
||||
runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
|
||||
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
|
||||
config, runtime_config, qbatch, update_pos, non_eos);
|
||||
}
|
||||
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
|
|
@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config,
|
|||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, /*update_pos=*/true, env, non_eos, timing_info);
|
||||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
|||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int, float)>;
|
||||
// If not empty, SampleFunc is called with the query_idx and logits for the
|
||||
// next token, which it may modify/overwrite. It returns the next generated
|
||||
// token together with its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
|
||||
// If not empty, SampleFunc is called concurrently from worker thread(s) with
|
||||
// query_idx, pos, logits for the next token (which it may modify/overwrite),
|
||||
// and worker. It returns the next generated token and its probability.
|
||||
using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
|
||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
|
|
@ -115,6 +115,7 @@ using ActivationsObserverFunc =
|
|||
struct RuntimeConfig {
|
||||
// If non-null, `batch_stream_token` is called for each token in the batch,
|
||||
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
||||
// This is called sequentially from the main thread.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
PROFILER_ZONE("Gen.StreamToken");
|
||||
if (batch_stream_token) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue