Add const restrictions to benchmark_helper.cc and paligemma_helper.cc to remove a few unnecessary copies.

PiperOrigin-RevId: 807004597
Charles Zhao 2025-09-14 16:26:55 -07:00 committed by Copybara-Service
parent c9b8479f7d
commit 59db30e209
4 changed files with 12 additions and 16 deletions

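The pattern this commit removes: a parameter of type std::string& cannot bind to a const string or a temporary, so call sites had to make a throwaway mutable copy first. A minimal standalone sketch of the before/after (the Tokenize stubs are hypothetical stand-ins for gcpp::WrapAndTokenize, not the library API):

#include <string>
#include <vector>

// Old style: non-const reference, even though the input is never modified.
std::vector<int> TokenizeOld(std::string& input) {
  return std::vector<int>(input.begin(), input.end());
}
// New style: const reference.
std::vector<int> TokenizeNew(const std::string& input) {
  return std::vector<int>(input.begin(), input.end());
}

std::vector<int> Caller(const std::string& input) {
  // Before: `input` is const, so it cannot bind to std::string&;
  // each call site had to make a throwaway mutable copy.
  std::string mutable_prompt = input;
  TokenizeOld(mutable_prompt);

  // After: a const reference binds directly; no copy is made.
  return TokenizeNew(input);
}

int main() { Caller("hello"); }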

@@ -137,24 +137,21 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   return res;
 }
 
-QueryResult GemmaEnv::QueryModel(std::string& input) {
+QueryResult GemmaEnv::QueryModel(const std::string& input) {
   const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }
 
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
-  std::vector<std::vector<int>> prompts;
-  prompts.reserve(inputs.size());
-  for (auto& input : inputs) {
-    std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(mutable_prompt));
-  }
   std::vector<PromptTokens> prompt_vector;
-  prompt_vector.reserve(prompts.size());
-  for (auto& prompt : prompts) {
+  prompt_vector.reserve(inputs.size());
+  for (auto& input : inputs) {
+    std::vector<int> prompt = WrapAndTokenize(input);
     prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
   QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
   return BatchQueryModel(prompt_span);
 }

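One caution worth noting about the rewritten loop above: assuming PromptTokens is a non-owning view (the data()/size() construction suggests it), the vectors that own the tokens must outlive the spans built over them. A minimal sketch of a shape that guarantees this, with hypothetical names:

#include <cstddef>
#include <vector>

// Hypothetical non-owning view, standing in for PromptTokens.
struct IntSpan {
  const int* data;
  std::size_t size;
};

std::vector<IntSpan> MakeSpans(const std::vector<std::vector<int>>& owners) {
  std::vector<IntSpan> spans;
  spans.reserve(owners.size());
  for (const auto& tokens : owners) {
    spans.push_back(IntSpan{tokens.data(), tokens.size()});
  }
  return spans;  // valid only while `owners` stays alive
}

int main() {
  // The owning container must outlive every span built over it.
  std::vector<std::vector<int>> owners = {{1, 2, 3}, {4, 5}};
  std::vector<IntSpan> spans = MakeSpans(owners);
  (void)spans;
}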

@@ -68,7 +68,7 @@ class GemmaEnv {
     return tokens;
   }
 
-  std::vector<int> WrapAndTokenize(std::string& input) const {
+  std::vector<int> WrapAndTokenize(const std::string& input) const {
     return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
                                  gemma_.Config().wrapping, 0, input);
   }
@@ -87,7 +87,7 @@ class GemmaEnv {
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
 
   // Adds turn structure to input, tokenizes and calls the above overload.
-  QueryResult QueryModel(std::string& input);
+  QueryResult QueryModel(const std::string& input);
   std::vector<QueryResult> BatchQueryModel(
       const std::vector<std::string>& inputs);

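A practical consequence of the const-reference signatures in this header: const lvalues and temporaries can now be passed directly. A small sketch of the binding rules (stub functions, not the GemmaEnv API):

#include <string>

void QueryOld(std::string& input);         // old signature (declaration only)
void QueryNew(const std::string& input) {}

int main() {
  const std::string prompt = "Hello";
  // QueryOld(prompt);     // error: cannot bind a const lvalue to std::string&
  // QueryOld("literal");  // error: cannot bind a temporary to std::string&
  QueryNew(prompt);        // ok
  QueryNew("literal");     // ok: a temporary binds to const&
}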

@@ -43,8 +43,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
     return true;
   };
 
-  std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
+  std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
   tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
 
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,

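The unchanged tokens.insert line above prepends one placeholder token (id 0) per image-token row, using the fill overload insert(pos, count, value) of std::vector. A standalone sketch with assumed values:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<int> tokens = {101, 102, 103};  // assumed text token ids
  const std::size_t image_rows = 4;           // stand-in for image_tokens_->Rows()

  // Fill overload: inserts `image_rows` copies of 0 at the front.
  tokens.insert(tokens.begin(), image_rows, 0);

  assert(tokens.size() == 7);
  assert(tokens[0] == 0 && tokens[3] == 0 && tokens[4] == 101);
}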

@@ -52,7 +52,7 @@ class GemmaModel {
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
-  void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
+  void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream,
                   size_t max_generated_tokens, float temperature,
                   float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
     std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
@@ -75,7 +75,7 @@ class GemmaModel {
   }
 
   // Generates a single example, given a prompt, and returns the result.
-  std::string Generate(std::string prompt, size_t max_generated_tokens,
+  std::string Generate(const std::string& prompt, size_t max_generated_tokens,
                        float temperature, float /*seed*/,
                        const std::vector<std::string>& accept,
                        const std::vector<std::string>& end) {
@@ -192,7 +192,7 @@ class GemmaModel {
   // Generates a response to the given prompt, using the last set image.
   // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
   std::pair<std::string, std::vector<int>> GenerateWithImage(
-      std::string prompt, size_t max_generated_tokens, float temperature,
+      const std::string& prompt, size_t max_generated_tokens, float temperature,
       float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
     const gcpp::Gemma& model = *env_.GetGemma();
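
The three signature changes in this file share the same rationale: taking std::string by value copies the prompt on every call, while const std::string& does not. A small instrumented sketch that counts copies (all names hypothetical):

#include <iostream>
#include <string>

static int copies = 0;

// Wrapper that counts copy constructions of the prompt text.
struct Prompt {
  std::string text;
  Prompt(const char* t) : text(t) {}
  Prompt(const Prompt& other) : text(other.text) { ++copies; }
};

void ByValue(Prompt p) {}            // old style: copies the argument
void ByConstRef(const Prompt& p) {}  // new style: no copy

int main() {
  Prompt p("a long prompt string ...");
  ByValue(p);
  ByConstRef(p);
  std::cout << "copies: " << copies << "\n";  // prints "copies: 1"
}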