Fix pos assertions, refs #665

Ensure the streaming function's pos matches the number of calls.
Add two arguments that control the pos+1 and pos+=1 behavior.
Also clean up and add comments.
run: use batch_stream_func, add assert, higher verbosity for MM autotune output
PiperOrigin-RevId: 799511163
Jan Wassenberg 2025-08-26 04:50:06 -07:00 committed by Copybara-Service
parent 9bf0fe4e37
commit ed2f0bd1b0
3 changed files with 73 additions and 62 deletions
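
For reference, the invariant this change establishes for callers: the pos argument passed to the streaming callback equals the number of tokens streamed so far (prefill plus generated), so a caller can keep its own counter and assert equality, as run.cc does below. A minimal caller-side sketch, assuming only the callback signature visible in this diff; the function name and the abs_pos counter here are illustrative, not part of the commit:

// Hypothetical caller-side callback relying on the new pos invariant:
// pos should equal the number of prior calls for this session.
#include <cstddef>
#include <cstdio>

static size_t abs_pos = 0;  // tokens streamed so far (prefill + generated)

static bool BatchStreamToken(size_t /*query_idx*/, size_t pos, int token,
                             float /*prob*/) {
  if (pos != abs_pos) {  // with this commit, this should not trigger
    std::fprintf(stderr, "pos mismatch: %zu vs %zu\n", pos, abs_pos);
    return false;  // returning false asks generation to stop early
  }
  ++abs_pos;
  (void)token;  // a real callback would decode and print the token here
  return true;
}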

File 1 of 3

@@ -138,6 +138,10 @@ TEST_F(GemmaTest, Multiturn) {
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
+  // -1 because our prefill does not generate KVs for the last token. Do not
+  // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+  HWY_ASSERT(abs_pos > 0);
+  --abs_pos;
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());

File 2 of 3

@@ -127,9 +127,10 @@ static float EmbeddingScaling(size_t model_dim) {
       hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
 }
 
-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
+// `x_row` indicates which row of `x` to write to.
+// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not
+// the start of the batch, because this is called for batches of tokens in
+// prefill, but batches of queries in decode.
 //
 // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
 // spec) until we run out of image tokens. This allows for a multi-image prompt
@@ -137,7 +138,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const WeightsPtrs& weights,
             MatStorageT<float>& x, ThreadingContext& ctx,
             const ImageTokens* image_tokens = nullptr,
@@ -146,14 +147,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }
 
   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }
@@ -174,14 +175,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
               ctx.profiler, worker);
   });
 
   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos);
   }
   return image_token_position;
 }
@@ -249,24 +250,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
     for (size_t ti = 0; ti < tbatch_size; ++ti) {
       const size_t pos = qbatch_1.Pos(0) + ti;
       const size_t pos_in_prompt = tbatch_start + ti;
-      HWY_DASSERT(pos_in_prompt < prompt_size);
       const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       image_token_position = EmbedMMToken(
           token, ti, pos, pos_in_prompt, config, weights, activations.x,
           env.ctx, runtime_config.image_tokens, image_token_position);
-    }
-    // Transformer with one batch of tokens from a single query.
-    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
-         ++layer_idx) {
-      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, qbatch_1, env);
-    }
 
-    // NOTE: we unconditionally call StreamToken, even if EOS.
-    for (size_t ti = 0; ti < tbatch_size; ++ti) {
-      const size_t pos = qbatch_1.Pos(0) + ti;
-      const size_t pos_in_prompt = tbatch_start + ti;
-      const int token = qbatch_1.Prompt(0)[pos_in_prompt];
+      // NOTE: we unconditionally call StreamToken, even if EOS.
       if (pos_in_prompt < prompt_size - 1) {
         runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
       } else {
@@ -276,6 +265,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
       }
     }
+
+    // Transformer with one batch of tokens from a single query. No need to
+    // set `PrevToken` because we already did the embedding above.
+    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
+         ++layer_idx) {
+      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                       activations, qbatch_1, env);
+    }
     qbatch_1.MutablePos(0) += tbatch_size;
   }  // for tbatch_start
 
   if (attend_to_last_token) {
@@ -291,8 +288,8 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 }
 
 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
-// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
-// token-batched `PrefillTBatch`.
+// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
+// token-batched `PrefillTBatch`, which supports image embedding.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
                                      const WeightsPtrs& weights,
@@ -324,8 +321,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
   }
 }
 
-// Populates KV cache for the batch queries, one token at a time. Only called
-// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
+// Populates KV cache for the batch queries, one token at a time.
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
@@ -337,6 +333,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
+
+    // Should only be called for autoregressive (non-prefix-LM) prefill.
     HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }
 
@@ -358,7 +356,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     }
 
     // The input (PrevToken) is one token from each query in the batch.
-    // Do not call DecodeStepT because it computes logits for token
+    // Do not call `SampleAndStream` because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }
@@ -369,42 +367,40 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 }
 
 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
-// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
 // query is at the end of its sequence.
 static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
                                const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+                               QBatch& qbatch, bool pos_plus_1, bool update_pos,
+                               hwy::BitSet4096<>& non_eos) {
   HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.
 
-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
-                                               qbatch.Pos(qi), token, prob))) {
+  const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
+  if (HWY_UNLIKELY(
+          !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
     // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
     HWY_DASSERT(config.IsEOS(token));
   }
 
   qbatch.PrevToken(qi) = token;
-  qbatch.MutablePos(qi) += 1;
+  qbatch.MutablePos(qi) += update_pos ? 1 : 0;
 
   // Primary or secondary EOS: mark query as EOS, but still increment (for
   // multi-turn, we should still keep the prior EOS).
   if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }
 
-// For a batch of queries, runs Transformer, computes logits, samples and
-// streams the token.
-static void DecodeStepT(const ModelConfig& config,
-                        const RuntimeConfig& runtime_config,
-                        const WeightsPtrs& weights,
-                        const SampleFunc& sample_token,
-                        Activations& activations, QBatch& qbatch,
-                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-                        TimingInfo& timing_info) {
+// Must be called after Transformer: either after prefill, or during decode.
+// Computes logits, samples and streams the token.
+static void SampleAndStream(
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const WeightsPtrs& weights, const SampleFunc& sample_token,
+    Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());
-  Transformer(config, runtime_config, weights, activations, qbatch, env);
 
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
@@ -427,8 +423,12 @@ static void DecodeStepT(const ModelConfig& config,
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
 
+    // We streamed all prefill tokens, but pos is still one behind because we
+    // started generation at pos = prompt.size() - 1. We want the pos argument
+    // to match the number of calls to `StreamToken`, as expected by the caller.
+    const bool pos_plus_1 = true;
     StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
-                       non_eos);
+                       pos_plus_1, update_pos, non_eos);
   });
 
 }
@@ -476,15 +476,16 @@ static void GenerateT(const ModelConfig& config,
   const size_t seq_len = qbatch.KV(0).SeqLen();
 
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
+    // Sanity check: prompts should not be empty. Note that multi-turn prompts
+    // start with <end_of_turn>.
+    HWY_ASSERT(prompt.size() != 0);
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
 
     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
     total_prefill_tokens += prompt.size() - 1;
 
-    // Sanity check: prompts should not be empty, nor start with EOS.
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
     all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
 
     // We use a single divisor, so all sequence lengths must be the same.
@@ -518,14 +519,12 @@ static void GenerateT(const ModelConfig& config,
   // Stream the last prompt token from each query, fill activations.gen_tokens.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    const bool pos_plus_1 = false;  // during prefill, pos is still correct.
+    // In autoregressive mode, we have not prefilled the last token, so do
+    // not advance.
+    const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
     StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
-                       runtime_config, qbatch, non_eos);
-    // StreamAndUpdateEOS() sets the stream position one token too far in
-    // autoregressive mode.
-    const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
-    if (!attend_to_last_token) {
-      qbatch.MutablePos(qi) -= 1;
-    }
+                       runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -540,8 +539,10 @@ static void GenerateT(const ModelConfig& config,
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(config, runtime_config, weights, sample_token, activations,
-                  qbatch, env, non_eos, timing_info);
+      Transformer(config, runtime_config, weights, activations, qbatch, env);
+      SampleAndStream(config, runtime_config, weights, sample_token,
+                      activations, qbatch, /*update_pos=*/true, env, non_eos,
+                      timing_info);
     }
     timing_info.NotifyGenerateDone();
   }

File 3 of 3

@@ -132,7 +132,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }
 
   // callback function invoked for each generated token.
-  auto stream_token = [&](int token, float) {
+  auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
+                                float) {
+    std::string token_text;
+    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
@@ -148,8 +153,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
       return true;
     }
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     if (first_response_token) {
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
       if (inference.verbosity >= 1) {
@@ -187,7 +190,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   TimingInfo timing_info = {.verbosity = inference.verbosity};
   RuntimeConfig runtime_config = {.gen = &gen,
                                   .verbosity = inference.verbosity,
-                                  .stream_token = stream_token,
+                                  .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
@@ -223,6 +226,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
+    // -1 because our prefill does not generate KVs for the last token. Do not
+    // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+    if (abs_pos > 0) --abs_pos;
     gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
     std::cout << "\n\n";
@@ -255,7 +261,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
-  if (inference.verbosity >= 2) env.print_best = true;
+  if (inference.verbosity >= 3) env.print_best = true;
 
   const Gemma gemma(loader, inference, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);