mirror of https://github.com/google/gemma.cpp.git
Fix batch inference: dangling reference
Also add more detailed asserts/error messages.

PiperOrigin-RevId: 807695421
This commit is contained in:
parent f3bc1c17da
commit b603425bf3
@@ -105,8 +105,14 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
                             const size_t pos,
                             const int token, float) {
     HWY_ASSERT(query_index < num_queries);
+    if (token >= gemma_.Config().vocab_size) {
+      HWY_ABORT("Token %d >= vocab size %d", token, gemma_.Config().vocab_size);
+    }
     std::string token_text;
-    HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    if (!gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
+      HWY_ABORT("Failed to decode token %d, tokenizer bytes %s\n", token,
+                gemma_.Tokenizer().Serialize().substr(0, 10).c_str());
+    }
     res[query_index].response.append(token_text);
     HWY_ASSERT(pos == res[query_index].tokens_generated);
     res[query_index].tokens_generated += 1;
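The first hunk tightens error reporting: a bare HWY_ASSERT only reports the failed condition, whereas an explicit check followed by HWY_ABORT can carry the runtime values that make a batch failure debuggable. A minimal sketch of the difference, assuming Highway's HWY_ASSERT/HWY_ABORT macros from hwy/base.h (CheckToken is a hypothetical helper, not part of the commit):

#include "hwy/base.h"

// Hypothetical helper illustrating the two styles seen in the hunk above.
void CheckToken(int token, int vocab_size) {
  // Bare assert: on failure only the stringified condition is reported;
  // the offending token value is lost.
  HWY_ASSERT(token < vocab_size);

  // Explicit check + formatted abort: the message names the bad token and
  // the limit, which is the style the commit switches to.
  if (token >= vocab_size) {
    HWY_ABORT("Token %d >= vocab size %d", token, vocab_size);
  }
}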
@@ -143,17 +149,19 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
 }
 
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
-    const std::vector<std::string>& inputs) {
-  std::vector<PromptTokens> prompt_vector;
-  prompt_vector.reserve(inputs.size());
+    const std::vector<std::string>& prompt_strings) {
+  std::vector<PromptTokens> views;
+  views.reserve(prompt_strings.size());
 
-  for (auto& input : inputs) {
-    std::vector<int> prompt = WrapAndTokenize(input);
-    prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
+  std::vector<std::vector<int>> storage;
+  storage.reserve(prompt_strings.size());
+  for (auto& input : prompt_strings) {
+    storage.push_back(WrapAndTokenize(input));
+    views.push_back(PromptTokens(storage.back().data(), storage.back().size()));
   }
 
-  QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
-  return BatchQueryModel(prompt_span);
+  QueriesPromptTokens span_of_views(views.data(), views.size());
+  return BatchQueryModel(span_of_views);
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
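This hunk is the dangling-reference fix named in the commit message. The old loop tokenized into a loop-local std::vector<int> and stored a PromptTokens view of its buffer, so every view pointed at freed memory once its iteration ended; the new code keeps the token vectors alive in `storage`, which outlives the views until the inner BatchQueryModel overload returns. A self-contained sketch of the bug shape and the fix, using a hypothetical View struct in place of PromptTokens (assumed here to be a non-owning pointer-plus-size view) and a stand-in Tokenize instead of WrapAndTokenize:

#include <cstddef>
#include <string>
#include <vector>

// Hypothetical non-owning view, standing in for PromptTokens.
struct View {
  const int* data;
  size_t size;
};

// Stand-in for WrapAndTokenize: returns a freshly allocated token vector.
std::vector<int> Tokenize(const std::string& s) {
  return std::vector<int>(s.begin(), s.end());
}

// BUG (old code shape): `tokens` is destroyed at the end of each iteration,
// so every stored View already dangles when the function returns.
std::vector<View> BuildViewsDangling(const std::vector<std::string>& inputs) {
  std::vector<View> views;
  for (const std::string& input : inputs) {
    std::vector<int> tokens = Tokenize(input);
    views.push_back(View{tokens.data(), tokens.size()});
  }  // `tokens` freed here.
  return views;
}

// FIX (new code shape): `storage` owns the token vectors and must outlive
// every View taken from it, as it does in the rewritten BatchQueryModel.
std::vector<View> BuildViews(const std::vector<std::string>& inputs,
                             std::vector<std::vector<int>>& storage) {
  std::vector<View> views;
  storage.reserve(inputs.size());
  for (const std::string& input : inputs) {
    storage.push_back(Tokenize(input));
    views.push_back(View{storage.back().data(), storage.back().size()});
  }
  return views;
}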
@@ -89,7 +89,7 @@ class GemmaEnv {
   // Adds turn structure to input, tokenizes and calls the above overload.
   QueryResult QueryModel(const std::string& input);
   std::vector<QueryResult> BatchQueryModel(
-      const std::vector<std::string>& inputs);
+      const std::vector<std::string>& prompt_strings);
 
   // Runs inference on the given input and calls the callback for each token.
   void QueryModel(const std::vector<int>& tokens,
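The header hunk only renames the parameter to match the .cc change; the caller-facing shape of the API is unchanged: a vector of prompt strings in, one QueryResult per prompt out. A hedged usage sketch, assuming an already-constructed GemmaEnv named `env` (its constructor arguments are not shown in this diff) and the gemma.cpp header that declares GemmaEnv and QueryResult:

#include <iostream>
#include <string>
#include <vector>

// Assumes QueryResult exposes the `response` and `tokens_generated`
// members seen in the first hunk.
void RunBatch(GemmaEnv& env) {
  std::vector<std::string> prompts = {"What is the capital of France?",
                                      "Name a prime number."};
  std::vector<QueryResult> results = env.BatchQueryModel(prompts);
  for (const QueryResult& r : results) {
    std::cout << r.response << " (" << r.tokens_generated << " tokens)\n";
  }
}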