mirror of https://github.com/google/gemma.cpp.git
1.15x speedup: parallel sampling, enabled by new RNG
Also pass pos to SampleFunc, for seeding the RNG.
PiperOrigin-RevId: 803453518
This commit is contained in:
parent
ad7d7a2713
commit
cbe24eac51
|
|
@ -99,14 +99,15 @@ HWY_EXPORT(CallSoftmax);
|
||||||
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||||
MatMulEnv& env, int verbosity) {
|
MatMulEnv& env, int verbosity) {
|
||||||
const StreamFunc stream_token = [](int, float) { return true; };
|
const BatchStreamFunc stream_token = [](size_t, size_t, int, float) {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
const int vocab_size = gemma.Config().vocab_size;
|
const int vocab_size = gemma.Config().vocab_size;
|
||||||
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
float cross_entropy = std::log(vocab_size); // first token; == -log(1/v_s)
|
||||||
size_t pos = 1;
|
|
||||||
|
|
||||||
const SampleFunc sample_token = [&](size_t qi,
|
const SampleFunc sample_token = [&](size_t qi, size_t pos, Logits logits,
|
||||||
Logits logits) -> TokenAndProb {
|
size_t /*worker*/) -> TokenAndProb {
|
||||||
// input is logits, not yet probabilities
|
// input is logits, not yet probabilities
|
||||||
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
|
HWY_DYNAMIC_DISPATCH(CallSoftmax)(logits, env.ctx.profiler);
|
||||||
// We are called for each token, but pos starts at 1. Clamping
|
// We are called for each token, but pos starts at 1. Clamping
|
||||||
|
|
@ -128,7 +129,6 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||||
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
printf("Processed %zu tokens, cross-entropy per token: %f\n", pos + 1,
|
||||||
cross_entropy / std::log(2.0) / (pos + 1));
|
cross_entropy / std::log(2.0) / (pos + 1));
|
||||||
}
|
}
|
||||||
++pos;
|
|
||||||
return TokenAndProb{.token = token, .prob = prob};
|
return TokenAndProb{.token = token, .prob = prob};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -138,7 +138,7 @@ float ComputeCrossEntropy(const Gemma& gemma, size_t max_generated_tokens,
|
||||||
.max_generated_tokens = max_generated_tokens - 1,
|
.max_generated_tokens = max_generated_tokens - 1,
|
||||||
.temperature = 0.0f,
|
.temperature = 0.0f,
|
||||||
.verbosity = verbosity,
|
.verbosity = verbosity,
|
||||||
.stream_token = stream_token,
|
.batch_stream_token = stream_token,
|
||||||
.sample_func = sample_token,
|
.sample_func = sample_token,
|
||||||
};
|
};
|
||||||
TimingInfo timing_info;
|
TimingInfo timing_info;
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,7 @@ struct Activations {
|
||||||
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
|
x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
|
||||||
logits(
|
logits(
|
||||||
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
|
MatFactory("logits", batch_size, config.vocab_size, ctx.allocator)),
|
||||||
|
sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
|
||||||
|
|
||||||
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
|
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
|
||||||
config.model_dim, ctx.allocator)),
|
config.model_dim, ctx.allocator)),
|
||||||
|
|
@ -164,6 +165,7 @@ struct Activations {
|
||||||
x.OverrideRows(batch_size);
|
x.OverrideRows(batch_size);
|
||||||
x_bf.OverrideRows(batch_size);
|
x_bf.OverrideRows(batch_size);
|
||||||
logits.OverrideRows(batch_size);
|
logits.OverrideRows(batch_size);
|
||||||
|
sampled.OverrideRows(batch_size);
|
||||||
|
|
||||||
pre_ffw_rms_out.OverrideRows(batch_size);
|
pre_ffw_rms_out.OverrideRows(batch_size);
|
||||||
C1.OverrideRows(batch_size);
|
C1.OverrideRows(batch_size);
|
||||||
|
|
@ -178,6 +180,7 @@ struct Activations {
|
||||||
MatStorageT<float> x; // input
|
MatStorageT<float> x; // input
|
||||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||||
MatStorageT<float> logits;
|
MatStorageT<float> logits;
|
||||||
|
MatStorageT<uint32_t> sampled; // batch_size x 3 (padded)
|
||||||
|
|
||||||
// Gated FFW
|
// Gated FFW
|
||||||
MatStorageT<BF16> pre_ffw_rms_out;
|
MatStorageT<BF16> pre_ffw_rms_out;
|
||||||
|
|
|
||||||
|
|
@ -377,14 +377,13 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||||
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
||||||
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
|
// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
|
||||||
// query is at the end of its sequence.
|
// query is at the end of its sequence.
|
||||||
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token,
|
||||||
const ModelConfig& config,
|
const float prob, const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
QBatch& qbatch, bool pos_plus_1, bool update_pos,
|
QBatch& qbatch, bool update_pos,
|
||||||
hwy::BitSet4096<>& non_eos) {
|
hwy::BitSet4096<>& non_eos) {
|
||||||
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
||||||
|
|
||||||
const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
|
|
||||||
if (HWY_UNLIKELY(
|
if (HWY_UNLIKELY(
|
||||||
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
|
!runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
|
||||||
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
||||||
|
|
@ -402,11 +401,13 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
||||||
|
|
||||||
// Must be called after Transformer: either after prefill, or during decode.
|
// Must be called after Transformer: either after prefill, or during decode.
|
||||||
// Computes logits, samples and streams the token.
|
// Computes logits, samples and streams the token.
|
||||||
static void SampleAndStream(
|
static void SampleAndStream(const ModelConfig& config,
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const WeightsPtrs& weights, const SampleFunc& sample_token,
|
const WeightsPtrs& weights,
|
||||||
Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
|
const SampleFunc& sample_token,
|
||||||
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
Activations& activations, QBatch& qbatch,
|
||||||
|
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||||
|
TimingInfo& timing_info) {
|
||||||
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||||
|
|
||||||
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
|
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
|
||||||
|
|
@ -429,16 +430,33 @@ static void SampleAndStream(
|
||||||
|
|
||||||
timing_info.NotifyGenerated(non_eos.Count());
|
timing_info.NotifyGenerated(non_eos.Count());
|
||||||
|
|
||||||
// TODO: parallelize
|
ParallelFor(
|
||||||
non_eos.Foreach([&](size_t qi) {
|
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
|
||||||
const TokenAndProb tp = sample_token(qi, activations.logits.RowSpan(qi));
|
/*cluster_idx=*/0, [&](size_t qi, size_t worker) {
|
||||||
|
if (!non_eos.Get(qi)) return;
|
||||||
|
|
||||||
// We streamed all prefill tokens, but pos is still one behind because we
|
// We streamed all prefill tokens, but pos is still one behind
|
||||||
// started generation at pos = prompt.size() - 1. We want the pos argument
|
// because we started generation at pos = prompt.size() - 1.
|
||||||
// to match the number of calls to `StreamToken`, as expected by the caller.
|
// We want the pos argument to match the number of calls to
|
||||||
const bool pos_plus_1 = true;
|
// `StreamToken`, as expected by the caller.
|
||||||
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
|
const size_t pos = qbatch.Pos(qi) + 1;
|
||||||
pos_plus_1, update_pos, non_eos);
|
|
||||||
|
const TokenAndProb tp =
|
||||||
|
sample_token(qi, pos, activations.logits.RowSpan(qi), worker);
|
||||||
|
// `sampled` is padded, which prevents false sharing.
|
||||||
|
activations.sampled.Row(qi)[0] = static_cast<uint32_t>(pos);
|
||||||
|
activations.sampled.Row(qi)[1] = static_cast<uint32_t>(tp.token);
|
||||||
|
activations.sampled.Row(qi)[2] = hwy::BitCastScalar<uint32_t>(tp.prob);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Sequentially, because `StreamToken` is not yet thread-safe.
|
||||||
|
non_eos.Foreach([&](size_t qi) {
|
||||||
|
const size_t pos = activations.sampled.Row(qi)[0];
|
||||||
|
const int token = static_cast<int>(activations.sampled.Row(qi)[1]);
|
||||||
|
const float prob =
|
||||||
|
hwy::BitCastScalar<float>(activations.sampled.Row(qi)[2]);
|
||||||
|
StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch,
|
||||||
|
/*update_pos=*/true, non_eos);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -448,21 +466,25 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
||||||
// If user provided a sample_func, use it.
|
// If user provided a sample_func, use it.
|
||||||
if (runtime_config.sample_func) return runtime_config.sample_func;
|
if (runtime_config.sample_func) return runtime_config.sample_func;
|
||||||
|
|
||||||
static const auto zone = ctx.profiler.AddZone("Gen.Sample Top1");
|
static const auto zone_top1 = ctx.profiler.AddZone("Gen.Sample Top1");
|
||||||
const size_t worker = 0; // TODO: parallelize
|
static const auto zone_topK = ctx.profiler.AddZone("Gen.Sample general");
|
||||||
|
|
||||||
// Fast path for top-1 with no accept_token.
|
// Fast path for top-1 with no accept_token.
|
||||||
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
|
||||||
return [&](size_t /*qi*/, Logits logits) HWY_ATTR -> TokenAndProb {
|
return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker)
|
||||||
PROFILER_ZONE3(ctx.profiler, worker, zone);
|
HWY_ATTR -> TokenAndProb {
|
||||||
return Top1OfSoftmax(logits);
|
PROFILER_ZONE3(ctx.profiler, worker, zone_top1);
|
||||||
};
|
return Top1OfSoftmax(logits);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// General case: Softmax with top-k sampling.
|
// General case: Softmax with top-k sampling.
|
||||||
return [&](size_t qi, Logits logits) HWY_ATTR -> TokenAndProb {
|
return [&](size_t qi, size_t pos, Logits logits,
|
||||||
PROFILER_ZONE("Gen.Sample general");
|
size_t worker) HWY_ATTR -> TokenAndProb {
|
||||||
RngStream gen(engine, qi);
|
PROFILER_ZONE3(ctx.profiler, worker, zone_topK);
|
||||||
|
// We want a different sequence for each batch element and position.
|
||||||
|
const uint64_t stream = (static_cast<uint64_t>(qi) << 32) | pos;
|
||||||
|
RngStream gen(engine, stream);
|
||||||
return FusedSoftmaxAndSampleTopK(
|
return FusedSoftmaxAndSampleTopK(
|
||||||
logits, runtime_config.top_k, gen, runtime_config.temperature,
|
logits, runtime_config.top_k, gen, runtime_config.temperature,
|
||||||
runtime_config.accept_token, ctx.profiler, worker);
|
runtime_config.accept_token, ctx.profiler, worker);
|
||||||
|
|
@ -524,12 +546,13 @@ static void GenerateT(const ModelConfig& config,
|
||||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||||
const bool pos_plus_1 = false; // during prefill, pos is still correct.
|
|
||||||
|
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
||||||
// In autoregressive mode, we have not prefilled the last token, so do
|
// In autoregressive mode, we have not prefilled the last token, so do
|
||||||
// not advance.
|
// not advance.
|
||||||
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
|
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
|
||||||
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
|
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
|
||||||
runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
|
config, runtime_config, qbatch, update_pos, non_eos);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||||
|
|
@ -546,7 +569,7 @@ static void GenerateT(const ModelConfig& config,
|
||||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||||
qbatch, /*update_pos=*/true, env, non_eos, timing_info);
|
qbatch, env, non_eos, timing_info);
|
||||||
}
|
}
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -89,10 +89,10 @@ using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||||
// If not empty, AcceptFunc is called with token. It should return false for
|
// If not empty, AcceptFunc is called with token. It should return false for
|
||||||
// tokens you don't want to generate and true for tokens you want to generate.
|
// tokens you don't want to generate and true for tokens you want to generate.
|
||||||
using AcceptFunc = std::function<bool(int, float)>;
|
using AcceptFunc = std::function<bool(int, float)>;
|
||||||
// If not empty, SampleFunc is called with the query_idx and logits for the
|
// If not empty, SampleFunc is called concurrently from worker thread(s) with
|
||||||
// next token, which it may modify/overwrite. It returns the next generated
|
// query_idx, pos, logits for the next token (which it may modify/overwrite),
|
||||||
// token together with its probability.
|
// and worker. It returns the next generated token and its probability.
|
||||||
using SampleFunc = std::function<TokenAndProb(size_t, Logits)>;
|
using SampleFunc = std::function<TokenAndProb(size_t, size_t, Logits, size_t)>;
|
||||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||||
// - index of query within containing batch (if any); zero otherwise.
|
// - index of query within containing batch (if any); zero otherwise.
|
||||||
// - position in the tokens sequence
|
// - position in the tokens sequence
|
||||||
|
|
@ -115,6 +115,7 @@ using ActivationsObserverFunc =
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
// If non-null, `batch_stream_token` is called for each token in the batch,
|
// If non-null, `batch_stream_token` is called for each token in the batch,
|
||||||
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
||||||
|
// This is called sequentially from the main thread.
|
||||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||||
PROFILER_ZONE("Gen.StreamToken");
|
PROFILER_ZONE("Gen.StreamToken");
|
||||||
if (batch_stream_token) {
|
if (batch_stream_token) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue