Fix batch inference: dangling reference

Also add more detailed asserts/error messages.
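
For context, a minimal standalone sketch of the dangling-reference pattern this commit fixes. `TokenSpan` and `Tokenize` below are hypothetical stand-ins for PromptTokens and WrapAndTokenize; this is an illustration of the lifetime bug and its fix, not the library code itself:

    #include <cstddef>
    #include <cstdio>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for PromptTokens: a non-owning (pointer, size) view.
    struct TokenSpan {
      const int* data;
      std::size_t size;
    };

    // Hypothetical stand-in for WrapAndTokenize.
    std::vector<int> Tokenize(const std::string& s) {
      return std::vector<int>(s.begin(), s.end());
    }

    int main() {
      const std::vector<std::string> prompts = {"hello", "world"};

      // Buggy pattern (pre-fix): `tokens` is destroyed at the end of each
      // loop iteration, so every stored view dangles.
      //   for (const auto& p : prompts) {
      //     std::vector<int> tokens = Tokenize(p);
      //     views.push_back({tokens.data(), tokens.size()});  // dangling!
      //   }

      // Fixed pattern (as in this commit): keep the token vectors alive in
      // `storage` for as long as the views are used; reserve() up front,
      // mirroring the committed code, to avoid reallocation as it grows.
      std::vector<std::vector<int>> storage;
      storage.reserve(prompts.size());
      std::vector<TokenSpan> views;
      views.reserve(prompts.size());
      for (const auto& p : prompts) {
        storage.push_back(Tokenize(p));
        views.push_back({storage.back().data(), storage.back().size()});
      }

      for (const TokenSpan& v : views) {
        std::printf("%zu tokens\n", v.size);
      }
      return 0;
    }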

PiperOrigin-RevId: 807695421
Jan Wassenberg 2025-09-16 08:01:21 -07:00 committed by Copybara-Service
parent f3bc1c17da
commit b603425bf3
2 changed files with 18 additions and 10 deletions


@@ -105,8 +105,14 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
                        const size_t pos,
                        const int token, float) {
     HWY_ASSERT(query_index < num_queries);
+    if (token >= gemma_.Config().vocab_size) {
+      HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size);
+    }
     std::string token_text;
-    HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    if (!gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
+      HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token,
+                gemma_.Tokenizer().Serialize().substr(0, 10).c_str());
+    }
     res[query_index].response.append(token_text);
     HWY_ASSERT(pos == res[query_index].tokens_generated);
     res[query_index].tokens_generated += 1;
@@ -143,17 +149,19 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
 }
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
-    const std::vector<std::string>& inputs) {
-  std::vector<PromptTokens> prompt_vector;
-  prompt_vector.reserve(inputs.size());
+    const std::vector<std::string>& prompt_strings) {
+  std::vector<PromptTokens> views;
+  views.reserve(prompt_strings.size());
-  for (auto& input : inputs) {
-    std::vector<int> prompt = WrapAndTokenize(input);
-    prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
+  std::vector<std::vector<int>> storage;
+  storage.reserve(prompt_strings.size());
+  for (auto& input : prompt_strings) {
+    storage.push_back(WrapAndTokenize(input));
+    views.push_back(PromptTokens(storage.back().data(), storage.back().size()));
   }
-  QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
-  return BatchQueryModel(prompt_span);
+  QueriesPromptTokens span_of_views(views.data(), views.size());
+  return BatchQueryModel(span_of_views);
 }
 float GemmaEnv::CrossEntropy(const std::string& input) {


@@ -89,7 +89,7 @@ class GemmaEnv {
   // Adds turn structure to input, tokenizes and calls the above overload.
   QueryResult QueryModel(const std::string& input);
   std::vector<QueryResult> BatchQueryModel(
-      const std::vector<std::string>& inputs);
+      const std::vector<std::string>& prompt_strings);
   // Runs inference on the given input and calls the callback for each token.
   void QueryModel(const std::vector<int>& tokens,