mirror of https://github.com/google/gemma.cpp.git
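Fix batch inference: dangling reference (the per-prompt token views pointed into loop-local vectors that were destroyed before BatchQueryModel used them; the tokens are now kept alive in a separate storage vector).
Also add more detailed asserts/error messages.
PiperOrigin-RevId: 807695421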
This commit is contained in:
parent
f3bc1c17da
commit
b603425bf3
|
|
@ -105,8 +105,14 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
const size_t pos,
|
||||
const int token, float) {
|
||||
HWY_ASSERT(query_index < num_queries);
|
||||
if (token >= gemma_.Config().vocab_size) {
|
||||
HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size);
|
||||
}
|
||||
std::string token_text;
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (!gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
|
||||
HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token,
|
||||
gemma_.Tokenizer().Serialize().substr(0, 10).c_str());
|
||||
}
|
||||
res[query_index].response.append(token_text);
|
||||
HWY_ASSERT(pos == res[query_index].tokens_generated);
|
||||
res[query_index].tokens_generated += 1;
|
||||
|
|
@ -143,17 +149,19 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
|
|||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
const std::vector<std::string>& inputs) {
|
||||
std::vector<PromptTokens> prompt_vector;
|
||||
prompt_vector.reserve(inputs.size());
|
||||
const std::vector<std::string>& prompt_strings) {
|
||||
std::vector<PromptTokens> views;
|
||||
views.reserve(prompt_strings.size());
|
||||
|
||||
for (auto& input : inputs) {
|
||||
std::vector<int> prompt = WrapAndTokenize(input);
|
||||
prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
std::vector<std::vector<int>> storage;
|
||||
storage.reserve(prompt_strings.size());
|
||||
for (auto& input : prompt_strings) {
|
||||
storage.push_back(WrapAndTokenize(input));
|
||||
views.push_back(PromptTokens(storage.back().data(), storage.back().size()));
|
||||
}
|
||||
|
||||
QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
|
||||
return BatchQueryModel(prompt_span);
|
||||
QueriesPromptTokens span_of_views(views.data(), views.size());
|
||||
return BatchQueryModel(span_of_views);
|
||||
}
|
||||
|
||||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class GemmaEnv {
|
|||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
QueryResult QueryModel(const std::string& input);
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const std::vector<std::string>& inputs);
|
||||
const std::vector<std::string>& prompt_strings);
|
||||
|
||||
// Runs inference on the given input and calls the callback for each token.
|
||||
void QueryModel(const std::vector<int>& tokens,
|
||||
|
|
|
|||
Loading…
Reference in New Issue