From ed2f0bd1b0162109c5efe70edd77dbe2709ade7f Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 26 Aug 2025 04:50:06 -0700
Subject: [PATCH] Fix pos assertions, refs #665

Ensure the pos passed to the streaming func matches the number of calls.
Add two arguments that control the pos+1 and pos+=1 behavior.
Also clean up and add comments.

run: use batch_stream_func, add an assert, and require higher verbosity for
MM autotune output.

PiperOrigin-RevId: 799511163
---
 evals/gemma_test.cc |   4 ++
 gemma/gemma.cc      | 115 ++++++++++++++++++++++----------------------
 gemma/run.cc        |  16 ++++--
 3 files changed, 73 insertions(+), 62 deletions(-)

diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 12080f9..77efbae 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -138,6 +138,10 @@ TEST_F(GemmaTest, Multiturn) {
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
+  // -1 because our prefill does not generate KVs for the last token. Do not
+  // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+  HWY_ASSERT(abs_pos > 0);
+  --abs_pos;
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 95f52b8..b506e75 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -127,9 +127,10 @@ static float EmbeddingScaling(size_t model_dim) {
       hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
 }
 
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
+// `x_row` indicates which row of `x` to write to.
+// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not
+// the start of the batch, because this is called for batches of tokens in
+// prefill, but batches of queries in decode.
 //
 // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
 // spec) until we run out of image tokens. This allows for a multi-image prompt
@@ -137,7 +138,7 @@
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const WeightsPtrs& weights,
              MatStorageT<float>& x, ThreadingContext& ctx,
              const ImageTokens* image_tokens = nullptr,
@@ -146,14 +147,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }
 
   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }
@@ -174,14 +175,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
                ctx.profiler, worker);
   });
 
   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos);
   }
   return image_token_position;
 }
@@ -249,24 +250,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
     for (size_t ti = 0; ti < tbatch_size; ++ti) {
       const size_t pos = qbatch_1.Pos(0) + ti;
      const size_t pos_in_prompt = tbatch_start + ti;
+      HWY_DASSERT(pos_in_prompt < prompt_size);
       const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       image_token_position = EmbedMMToken(
           token, ti, pos, pos_in_prompt, config, weights, activations.x,
           env.ctx, runtime_config.image_tokens, image_token_position);
-    }
-
-    // Transformer with one batch of tokens from a single query.
-    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
-         ++layer_idx) {
-      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, qbatch_1, env);
-    }
-
-    // NOTE: we unconditionally call StreamToken, even if EOS.
-    for (size_t ti = 0; ti < tbatch_size; ++ti) {
-      const size_t pos = qbatch_1.Pos(0) + ti;
-      const size_t pos_in_prompt = tbatch_start + ti;
-      const int token = qbatch_1.Prompt(0)[pos_in_prompt];
+      // NOTE: we unconditionally call StreamToken, even if EOS.
       if (pos_in_prompt < prompt_size - 1) {
         runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
       } else {
@@ -276,6 +265,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
       }
     }
 
+    // Transformer with one batch of tokens from a single query. No need to
+    // set `PrevToken` because we already did the embedding above.
+    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
+         ++layer_idx) {
+      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                       activations, qbatch_1, env);
+    }
+
     qbatch_1.MutablePos(0) += tbatch_size;
   }  // for tbatch_start
 
   if (attend_to_last_token) {
@@ -291,8 +288,8 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 }
 
 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
-// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
-// token-batched `PrefillTBatch`.
+// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
+// token-batched `PrefillTBatch`, which supports image embedding.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
                                      const WeightsPtrs& weights,
@@ -324,8 +321,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
   }
 }
 
-// Populates KV cache for the batch queries, one token at a time. Only called
-// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
+// Populates KV cache for the batch queries, one token at a time.
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
@@ -337,6 +333,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
+
+    // Should only be called for autoregressive (non-prefix-LM) prefill.
     HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }
 
@@ -358,7 +356,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     }
 
     // The input (PrevToken) is one token from each query in the batch.
-    // Do not call DecodeStepT because it computes logits for token
+    // Do not call `SampleAndStream` because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
@@ -369,42 +367,40 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 }
 
 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
-// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
 // query is at the end of its sequence.
 static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
                                const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+                               QBatch& qbatch, bool pos_plus_1, bool update_pos,
+                               hwy::BitSet4096<>& non_eos) {
   HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.
-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
-                                               qbatch.Pos(qi), token, prob))) {
+  const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
+  if (HWY_UNLIKELY(
+          !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
     // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
     HWY_DASSERT(config.IsEOS(token));
   }
   qbatch.PrevToken(qi) = token;
-  qbatch.MutablePos(qi) += 1;
+  qbatch.MutablePos(qi) += update_pos ? 1 : 0;
 
   // Primary or secondary EOS: mark query as EOS, but still increment (for
   // multi-turn, we should still keep the prior EOS).
   if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }
 
-// For a batch of queries, runs Transformer, computes logits, samples and
-// streams the token.
-static void DecodeStepT(const ModelConfig& config,
-                        const RuntimeConfig& runtime_config,
-                        const WeightsPtrs& weights,
-                        const SampleFunc& sample_token,
-                        Activations& activations, QBatch& qbatch,
-                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-                        TimingInfo& timing_info) {
+// Must be called after Transformer: either after prefill, or during decode.
+// Computes logits, samples and streams the token.
+static void SampleAndStream(
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const WeightsPtrs& weights, const SampleFunc& sample_token,
+    Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());
 
-  Transformer(config, runtime_config, weights, activations, qbatch, env);
-
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
@@ -427,8 +423,12 @@ static void DecodeStepT(const ModelConfig& config,
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
 
+    // We streamed all prefill tokens, but pos is still one behind because we
+    // started generation at pos = prompt.size() - 1. We want the pos argument
+    // to match the number of calls to `StreamToken`, as expected by the caller.
+    const bool pos_plus_1 = true;
     StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
-                       non_eos);
+                       pos_plus_1, update_pos, non_eos);
   });
 }
 
@@ -476,15 +476,16 @@ static void GenerateT(const ModelConfig& config,
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
+    // Sanity check: prompts should not be empty. Note that multi-turn prompts
+    // start with <end_of_turn>.
+    HWY_ASSERT(prompt.size() != 0);
+
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
     total_prefill_tokens += prompt.size() - 1;
 
-    // Sanity check: prompts should not be empty, nor start with EOS.
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
-
     all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
 
     // We use a single divisor, so all sequence lengths must be the same.
@@ -518,14 +519,12 @@ static void GenerateT(const ModelConfig& config,
   // Stream the last prompt token from each query, fill activations.gen_tokens.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    const bool pos_plus_1 = false;  // during prefill, pos is still correct.
+    // In autoregressive mode, we have not prefilled the last token, so do
+    // not advance.
+    const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
     StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
-                       runtime_config, qbatch, non_eos);
-    // StreamAndUpdateEOS() sets the stream position one token too far in
-    // autoregressive mode.
-    const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
-    if (!attend_to_last_token) {
-      qbatch.MutablePos(qi) -= 1;
-    }
+                       runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -540,8 +539,10 @@ static void GenerateT(const ModelConfig& config,
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(config, runtime_config, weights, sample_token, activations,
-                  qbatch, env, non_eos, timing_info);
+      Transformer(config, runtime_config, weights, activations, qbatch, env);
+      SampleAndStream(config, runtime_config, weights, sample_token,
+                      activations, qbatch, /*update_pos=*/true, env, non_eos,
+                      timing_info);
     }
     timing_info.NotifyGenerateDone();
   }
diff --git a/gemma/run.cc b/gemma/run.cc
index 286c6ee..3915bf8 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -132,7 +132,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }
 
   // callback function invoked for each generated token.
-  auto stream_token = [&](int token, float) {
+  auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
+                                float) {
+    std::string token_text;
+    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
@@ -148,8 +153,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
       return true;
     }
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     if (first_response_token) {
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
       if (inference.verbosity >= 1) {
@@ -187,7 +190,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   TimingInfo timing_info = {.verbosity = inference.verbosity};
   RuntimeConfig runtime_config = {.gen = &gen,
                                   .verbosity = inference.verbosity,
-                                  .stream_token = stream_token,
+                                  .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
@@ -223,6 +226,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
+    // -1 because our prefill does not generate KVs for the last token. Do not
+    // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+    if (abs_pos > 0) --abs_pos;
     gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
     std::cout << "\n\n";
@@ -255,7 +261,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
 
   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
-  if (inference.verbosity >= 2) env.print_best = true;
+  if (inference.verbosity >= 3) env.print_best = true;
 
   const Gemma gemma(loader, inference, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
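
Usage sketch (illustrative only; `BatchStreamFunc` below is a hand-written
stand-in for the callback type declared in gemma/gemma.h, and the flow
condenses what run.cc and gemma_test.cc do above): after this change, the
`pos` passed to the streaming callback equals the number of prior callback
invocations, so a caller can keep its own counter, assert agreement, and
rewind by one before a follow-up Generate because prefill does not generate
KVs for the last prompt token.

  #include <cassert>
  #include <cstddef>
  #include <functional>

  // Stand-in for the batch streaming callback type; the real alias lives in
  // gemma/gemma.h.
  using BatchStreamFunc =
      std::function<bool(size_t query_idx, size_t pos, int token, float prob)>;

  int main() {
    size_t abs_pos = 0;  // number of callback invocations so far

    BatchStreamFunc batch_stream_token = [&](size_t /*query_idx*/, size_t pos,
                                             int /*token*/, float /*prob*/) {
      assert(pos == abs_pos);  // the invariant this patch establishes
      ++abs_pos;
      return true;  // keep generating
    };

    // First turn: prefill streams the prompt tokens, then decode streams the
    // generated tokens; every call passes pos == number of prior calls.
    for (size_t pos = 0; pos < 8; ++pos) {
      batch_stream_token(/*query_idx=*/0, pos, /*token=*/0, /*prob=*/0.0f);
    }

    // Before the next turn, rewind by one: prefill does not generate KVs for
    // the last token, and passing abs_pos - 1 to Generate while leaving
    // abs_pos unchanged would break the pos == abs_pos check above.
    if (abs_pos > 0) --abs_pos;
    return 0;
  }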