add const restriction for benchmark_helper.cc and paligemma_helper.cc to remove a few unnecessary copies.

PiperOrigin-RevId: 807004597
Charles Zhao 2025-09-14 16:26:55 -07:00 committed by Copybara-Service
parent c9b8479f7d
commit 59db30e209
4 changed files with 12 additions and 16 deletions


@@ -137,24 +137,21 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   return res;
 }
 
-QueryResult GemmaEnv::QueryModel(std::string& input) {
+QueryResult GemmaEnv::QueryModel(const std::string& input) {
   const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }
 
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
-  std::vector<std::vector<int>> prompts;
-  prompts.reserve(inputs.size());
-  for (auto& input : inputs) {
-    std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(mutable_prompt));
-  }
   std::vector<PromptTokens> prompt_vector;
-  prompt_vector.reserve(prompts.size());
-  for (auto& prompt : prompts) {
+  prompt_vector.reserve(inputs.size());
+  for (auto& input : inputs) {
+    std::vector<int> prompt = WrapAndTokenize(input);
     prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
   QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
   return BatchQueryModel(prompt_span);
 }
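For context, a minimal sketch of the copy this hunk removes. The TokenizeOld/TokenizeNew stubs below are hypothetical stand-ins for GemmaEnv::WrapAndTokenize; the point is only the reference binding:

#include <string>
#include <vector>

// Hypothetical stand-ins for the old and new WrapAndTokenize signatures.
std::vector<int> TokenizeOld(std::string& input) { return {int(input.size())}; }
std::vector<int> TokenizeNew(const std::string& input) { return {int(input.size())}; }

int main() {
  const std::vector<std::string> inputs = {"hello", "world"};
  for (const auto& input : inputs) {
    // Before: a non-const reference cannot bind to a const element, so each
    // prompt had to be copied into a mutable string first.
    std::string mutable_prompt = input;
    TokenizeOld(mutable_prompt);
    // After: the const reference binds directly and the copy disappears.
    TokenizeNew(input);
  }
}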


@@ -68,7 +68,7 @@ class GemmaEnv {
     return tokens;
   }
 
-  std::vector<int> WrapAndTokenize(std::string& input) const {
+  std::vector<int> WrapAndTokenize(const std::string& input) const {
     return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
                                  gemma_.Config().wrapping, 0, input);
   }
@@ -87,7 +87,7 @@ class GemmaEnv {
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
   // Adds turn structure to input, tokenizes and calls the above overload.
-  QueryResult QueryModel(std::string& input);
+  QueryResult QueryModel(const std::string& input);
   std::vector<QueryResult> BatchQueryModel(
       const std::vector<std::string>& inputs);
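A side effect of the const signatures in this header worth noting: temporaries can now be passed directly. A minimal sketch of the binding rule (the Env types are toy stand-ins, not the real GemmaEnv):

#include <string>

// Toy stand-ins modeling the signature change; not the real GemmaEnv.
struct EnvBefore {
  void QueryModel(std::string& input) {}
};
struct EnvAfter {
  void QueryModel(const std::string& input) {}
};

int main() {
  std::string s = "What is the capital of France?";
  EnvBefore before;
  before.QueryModel(s);       // OK: an lvalue binds to std::string&
  // before.QueryModel("hi"); // error: a temporary cannot bind to std::string&
  EnvAfter after;
  after.QueryModel(s);        // OK
  after.QueryModel("hi");     // OK: a temporary binds to const std::string&
}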


@@ -43,8 +43,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
     return true;
   };
 
-  std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
+  std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
   tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
 
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,


@@ -52,7 +52,7 @@ class GemmaModel {
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
-  void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
+  void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream,
                   size_t max_generated_tokens, float temperature,
                   float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
     std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
@@ -75,7 +75,7 @@ class GemmaModel {
   }
 
   // Generates a single example, given a prompt, and returns the result.
-  std::string Generate(std::string prompt, size_t max_generated_tokens,
+  std::string Generate(const std::string& prompt, size_t max_generated_tokens,
                        float temperature, float /*seed*/,
                        const std::vector<std::string>& accept,
                        const std::vector<std::string>& end) {
@@ -192,7 +192,7 @@ class GemmaModel {
   // Generates a response to the given prompt, using the last set image.
   // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
   std::pair<std::string, std::vector<int>> GenerateWithImage(
-      std::string prompt, size_t max_generated_tokens, float temperature,
+      const std::string& prompt, size_t max_generated_tokens, float temperature,
       float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
     const gcpp::Gemma& model = *env_.GetGemma();
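The hunks in this last file change by-value std::string parameters to const references, which saves one string copy per call whenever the argument is an lvalue. A minimal sketch that makes the copy visible; the CountedString type and both Generate* functions are hypothetical, not part of the repo:

#include <iostream>
#include <string>

// Hypothetical instrumented string that counts copy constructions.
struct CountedString {
  std::string s;
  static inline int copies = 0;  // requires C++17
  CountedString(const char* c) : s(c) {}
  CountedString(const CountedString& other) : s(other.s) { ++copies; }
};

void GenerateByValue(CountedString prompt) {}       // before: copies lvalue arguments
void GenerateByRef(const CountedString& prompt) {}  // after: never copies

int main() {
  CountedString prompt = "Describe this image.";
  GenerateByValue(prompt);
  std::cout << CountedString::copies << '\n';  // prints 1: one copy per call
  GenerateByRef(prompt);
  std::cout << CountedString::copies << '\n';  // still 1: no new copy
}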