mirror of https://github.com/google/gemma.cpp.git
Fix pos assertions, refs #665
Ensure the pos passed to the streaming function matches the number of calls. Add two arguments that control the pos+1 and pos+=1 behavior. Also clean up and add comments.

run.cc: use batch_stream_token, add an assert, and require higher verbosity for MatMul autotune output.

PiperOrigin-RevId: 799511163
parent 9bf0fe4e37
commit ed2f0bd1b0
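The contract behind this change: the `pos` value handed to the streaming callback should equal the number of tokens already streamed for that query, so the caller can keep a running counter and assert on it. Below is a minimal sketch of such a callback in the spirit of the `batch_stream_token` lambda from gemma/run.cc further down; the counter and the elided decode/print step are illustrative assumptions, not the exact library code.

// Sketch only: abs_pos counts tokens streamed so far for this conversation.
size_t abs_pos = 0;  // assumed caller-maintained state
auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
                              float /*prob*/) {
  (void)query_idx;
  HWY_ASSERT(pos == abs_pos);  // pos now matches the number of prior calls
  ++abs_pos;
  // ... decode `token` to text and print it ...
  return true;  // returning false requests an early stop (treated as EOS)
};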
@@ -138,6 +138,10 @@ TEST_F(GemmaTest, Multiturn) {
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
+  // -1 because our prefill does not generate KVs for the last token. Do not
+  // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+  HWY_ASSERT(abs_pos > 0);
+  --abs_pos;
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());
gemma/gemma.cc (113 changed lines)
@@ -127,9 +127,10 @@ static float EmbeddingScaling(size_t model_dim) {
       hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
 }
 
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
+// `x_row` indicates which row of `x` to write to.
+// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not
+// the start of the batch, because this is called for batches of tokens in
+// prefill, but batches of queries in decode.
 //
 // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
 // spec) until we run out of image tokens. This allows for a multi-image prompt
@@ -137,7 +138,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const WeightsPtrs& weights,
              MatStorageT<float>& x, ThreadingContext& ctx,
             const ImageTokens* image_tokens = nullptr,
@@ -146,14 +147,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }
 
   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }
@@ -174,14 +175,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
                ctx.profiler, worker);
   });
 
   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos);
   }
   return image_token_position;
 }
@@ -249,24 +250,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
     for (size_t ti = 0; ti < tbatch_size; ++ti) {
       const size_t pos = qbatch_1.Pos(0) + ti;
       const size_t pos_in_prompt = tbatch_start + ti;
+      HWY_DASSERT(pos_in_prompt < prompt_size);
       const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       image_token_position = EmbedMMToken(
           token, ti, pos, pos_in_prompt, config, weights, activations.x,
           env.ctx, runtime_config.image_tokens, image_token_position);
-    }
-
-    // Transformer with one batch of tokens from a single query.
-    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
-         ++layer_idx) {
-      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, qbatch_1, env);
-    }
-
       // NOTE: we unconditionally call StreamToken, even if EOS.
-    for (size_t ti = 0; ti < tbatch_size; ++ti) {
-      const size_t pos = qbatch_1.Pos(0) + ti;
-      const size_t pos_in_prompt = tbatch_start + ti;
-      const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       if (pos_in_prompt < prompt_size - 1) {
         runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
       } else {
@@ -276,6 +265,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
       }
     }
 
+    // Transformer with one batch of tokens from a single query. No need to
+    // set `PrevToken` because we already did the embedding above.
+    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
+         ++layer_idx) {
+      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                       activations, qbatch_1, env);
+    }
+
     qbatch_1.MutablePos(0) += tbatch_size;
   }  // for tbatch_start
   if (attend_to_last_token) {
@@ -291,8 +288,8 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 }
 
 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
-// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
-// token-batched `PrefillTBatch`.
+// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
+// token-batched `PrefillTBatch`, which supports image embedding.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
                                      const WeightsPtrs& weights,
@@ -324,8 +321,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
   }
 }
 
-// Populates KV cache for the batch queries, one token at a time. Only called
-// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
+// Populates KV cache for the batch queries, one token at a time.
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
@@ -337,6 +333,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
+
+    // Should only be called for autoregressive (non-prefix-LM) prefill.
     HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }
 
@@ -358,7 +356,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     }
 
     // The input (PrevToken) is one token from each query in the batch.
-    // Do not call DecodeStepT because it computes logits for token
+    // Do not call `SampleAndStream` because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
@@ -369,42 +367,40 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 }
 
 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
-// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
 // query is at the end of its sequence.
 static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
                                const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+                               QBatch& qbatch, bool pos_plus_1, bool update_pos,
+                               hwy::BitSet4096<>& non_eos) {
   HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.
 
-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
-                                               qbatch.Pos(qi), token, prob))) {
+  const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
+  if (HWY_UNLIKELY(
+          !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
     // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
    HWY_DASSERT(config.IsEOS(token));
   }
 
   qbatch.PrevToken(qi) = token;
-  qbatch.MutablePos(qi) += 1;
+  qbatch.MutablePos(qi) += update_pos ? 1 : 0;
 
   // Primary or secondary EOS: mark query as EOS, but still increment (for
   // multi-turn, we should still keep the prior EOS).
   if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }
 
-// For a batch of queries, runs Transformer, computes logits, samples and
-// streams the token.
-static void DecodeStepT(const ModelConfig& config,
-                        const RuntimeConfig& runtime_config,
-                        const WeightsPtrs& weights,
-                        const SampleFunc& sample_token,
-                        Activations& activations, QBatch& qbatch,
-                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-                        TimingInfo& timing_info) {
+// Must be called after Transformer: either after prefill, or during decode.
+// Computes logits, samples and streams the token.
+static void SampleAndStream(
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const WeightsPtrs& weights, const SampleFunc& sample_token,
+    Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());
 
-  Transformer(config, runtime_config, weights, activations, qbatch, env);
-
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
@@ -427,8 +423,12 @@ static void DecodeStepT(const ModelConfig& config,
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
 
+    // We streamed all prefill tokens, but pos is still one behind because we
+    // started generation at pos = prompt.size() - 1. We want the pos argument
+    // to match the number of calls to `StreamToken`, as expected by the caller.
+    const bool pos_plus_1 = true;
     StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
-                       non_eos);
+                       pos_plus_1, update_pos, non_eos);
   });
 }
 
@@ -476,15 +476,16 @@ static void GenerateT(const ModelConfig& config,
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
-    // Sanity check: prompts should not be empty. Note that multi-turn prompts
-    // start with <end_of_turn>.
-    HWY_ASSERT(prompt.size() != 0);
 
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
 
     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
     total_prefill_tokens += prompt.size() - 1;
 
+    // Sanity check: prompts should not be empty, nor start with EOS.
+    HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
+
     all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
+
     // We use a single divisor, so all sequence lengths must be the same.
@@ -518,14 +519,12 @@ static void GenerateT(const ModelConfig& config,
   // Stream the last prompt token from each query, fill activations.gen_tokens.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    const bool pos_plus_1 = false;  // during prefill, pos is still correct.
+    // In autoregressive mode, we have not prefilled the last token, so do
+    // not advance.
+    const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
     StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
-                       runtime_config, qbatch, non_eos);
-    // StreamAndUpdateEOS() sets the stream position one token too far in
-    // autoregressive mode.
-    const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
-    if (!attend_to_last_token) {
-      qbatch.MutablePos(qi) -= 1;
-    }
+                       runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -540,8 +539,10 @@ static void GenerateT(const ModelConfig& config,
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(config, runtime_config, weights, sample_token, activations,
-                  qbatch, env, non_eos, timing_info);
+      Transformer(config, runtime_config, weights, activations, qbatch, env);
+      SampleAndStream(config, runtime_config, weights, sample_token,
+                      activations, qbatch, /*update_pos=*/true, env, non_eos,
+                      timing_info);
     }
     timing_info.NotifyGenerateDone();
   }
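For clarity, here is a hedged, standalone illustration of the position bookkeeping that the new `pos_plus_1` and `update_pos` flags control (hypothetical helper, not gemma.cc code): when streaming the last prompt token, the stored position is already correct and, in autoregressive mode, is not advanced; during decode the reported position is the stored position plus one, so it keeps matching the number of `StreamToken` calls, and the stored position advances for the next step.

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for qbatch.MutablePos(qi); illustration only.
struct QueryPos {
  size_t pos = 0;
};

// Returns the position reported to the streaming callback and optionally
// advances the stored position, mirroring the two new flags.
size_t ReportAndMaybeAdvance(QueryPos& q, bool pos_plus_1, bool update_pos) {
  const size_t reported = q.pos + (pos_plus_1 ? 1 : 0);
  q.pos += update_pos ? 1 : 0;
  return reported;
}

int main() {
  QueryPos q;
  q.pos = 4;  // e.g. prompt.size() - 1 == 4 after prefilling a 5-token prompt
  // Last prompt token: pos is already correct; autoregressive mode has not
  // prefilled this token, so do not advance.
  assert(ReportAndMaybeAdvance(q, /*pos_plus_1=*/false, /*update_pos=*/false) == 4);
  // Decode steps: report pos + 1 so it matches the number of StreamToken
  // calls made so far, then advance for the next step.
  assert(ReportAndMaybeAdvance(q, /*pos_plus_1=*/true, /*update_pos=*/true) == 5);
  assert(ReportAndMaybeAdvance(q, /*pos_plus_1=*/true, /*update_pos=*/true) == 6);
  return 0;
}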
gemma/run.cc (16 changed lines)
@@ -132,7 +132,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }
 
   // callback function invoked for each generated token.
-  auto stream_token = [&](int token, float) {
+  auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
+                                float) {
+    std::string token_text;
+    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
@@ -148,8 +153,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
       return true;
     }
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     if (first_response_token) {
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
       if (inference.verbosity >= 1) {
@@ -187,7 +190,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   TimingInfo timing_info = {.verbosity = inference.verbosity};
   RuntimeConfig runtime_config = {.gen = &gen,
                                   .verbosity = inference.verbosity,
-                                  .stream_token = stream_token,
+                                  .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
@@ -223,6 +226,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
+    // -1 because our prefill does not generate KVs for the last token. Do not
+    // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+    if (abs_pos > 0) --abs_pos;
     gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
     std::cout << "\n\n";
@@ -255,7 +261,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
 
   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
-  if (inference.verbosity >= 2) env.print_best = true;
+  if (inference.verbosity >= 3) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
 