From b603425bf30b93dd8a88e4865b6353c471b8af43 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 16 Sep 2025 08:01:21 -0700 Subject: [PATCH] Fix batch inference: dangling reference Also add more detailed asserts/error messages. PiperOrigin-RevId: 807695421 --- evals/benchmark_helper.cc | 26 +++++++++++++++++--------- evals/benchmark_helper.h | 2 +- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index bd53845..abdef50 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -105,8 +105,14 @@ std::vector GemmaEnv::BatchQueryModel( const size_t pos, const int token, float) { HWY_ASSERT(query_index < num_queries); + if (token >= gemma_.Config().vocab_size) { + HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size); + } std::string token_text; - HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector{token}, &token_text)); + if (!gemma_.Tokenizer().Decode(std::vector{token}, &token_text)) { + HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token, + gemma_.Tokenizer().Serialize().substr(0, 10).c_str()); + } res[query_index].response.append(token_text); HWY_ASSERT(pos == res[query_index].tokens_generated); res[query_index].tokens_generated += 1; @@ -143,17 +149,19 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) { } std::vector GemmaEnv::BatchQueryModel( - const std::vector& inputs) { - std::vector prompt_vector; - prompt_vector.reserve(inputs.size()); + const std::vector& prompt_strings) { + std::vector views; + views.reserve(prompt_strings.size()); - for (auto& input : inputs) { - std::vector prompt = WrapAndTokenize(input); - prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size())); + std::vector> storage; + storage.reserve(prompt_strings.size()); + for (auto& input : prompt_strings) { + storage.push_back(WrapAndTokenize(input)); + views.push_back(PromptTokens(storage.back().data(), storage.back().size())); } - QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size()); - return BatchQueryModel(prompt_span); + QueriesPromptTokens span_of_views(views.data(), views.size()); + return BatchQueryModel(span_of_views); } float GemmaEnv::CrossEntropy(const std::string& input) { diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 81ccde6..75cf0d2 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -89,7 +89,7 @@ class GemmaEnv { // Adds turn structure to input, tokenizes and calls the above overload. QueryResult QueryModel(const std::string& input); std::vector BatchQueryModel( - const std::vector& inputs); + const std::vector& prompt_strings); // Runs inference on the given input and calls the callback for each token. void QueryModel(const std::vector& tokens,