diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index e9fdafb..bd53845 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -137,24 +137,21 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   return res;
 }
 
-QueryResult GemmaEnv::QueryModel(std::string& input) {
+QueryResult GemmaEnv::QueryModel(const std::string& input) {
   const std::vector<int> prompt = WrapAndTokenize(input);
   return QueryModel(prompt);
 }
 
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     const std::vector<std::string>& inputs) {
-  std::vector<std::vector<int>> prompts;
-  prompts.reserve(inputs.size());
-  for (auto& input : inputs) {
-    std::string mutable_prompt = input;
-    prompts.push_back(WrapAndTokenize(mutable_prompt));
-  }
   std::vector<PromptTokens> prompt_vector;
-  prompt_vector.reserve(prompts.size());
-  for (auto& prompt : prompts) {
+  prompt_vector.reserve(inputs.size());
+
+  for (auto& input : inputs) {
+    std::vector<int> prompt = WrapAndTokenize(input);
     prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
   }
+
   QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
   return BatchQueryModel(prompt_span);
 }
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 261daa4..81ccde6 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -68,7 +68,7 @@ class GemmaEnv {
     return tokens;
   }
 
-  std::vector<int> WrapAndTokenize(std::string& input) const {
+  std::vector<int> WrapAndTokenize(const std::string& input) const {
     return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
                                  gemma_.Config().wrapping, 0, input);
   }
@@ -87,7 +87,7 @@ class GemmaEnv {
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
   // Adds turn structure to input, tokenizes and calls the above overload.
-  QueryResult QueryModel(std::string& input);
+  QueryResult QueryModel(const std::string& input);
   std::vector<QueryResult> BatchQueryModel(
       const std::vector<std::string>& inputs);
 
diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc
index 449ee00..c32e925 100644
--- a/paligemma/paligemma_helper.cc
+++ b/paligemma/paligemma_helper.cc
@@ -43,8 +43,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
     return true;
   };
 
-  std::string mutable_prompt = prompt_text;
-  std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
+  std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
   tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
 
   RuntimeConfig runtime_config = {.max_generated_tokens = 512,
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index 2e39f68..1bab194 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -52,7 +52,7 @@ class GemmaModel {
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
-  void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
+  void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream,
                   size_t max_generated_tokens, float temperature,
                   float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
     std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
@@ -75,7 +75,7 @@ class GemmaModel {
   }
 
   // Generates a single example, given a prompt, and returns the result.
-  std::string Generate(std::string prompt, size_t max_generated_tokens,
+  std::string Generate(const std::string& prompt, size_t max_generated_tokens,
                        float temperature, float /*seed*/,
                        const std::vector<std::string>& accept,
                        const std::vector<std::string>& end) {
@@ -192,7 +192,7 @@ class GemmaModel {
   // Generates a response to the given prompt, using the last set image.
   // Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
   std::pair<std::string, std::vector<int>> GenerateWithImage(
-      std::string prompt, size_t max_generated_tokens, float temperature,
+      const std::string& prompt, size_t max_generated_tokens, float temperature,
       float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
     if (!image_tokens_) throw std::invalid_argument("No image set.");
     const gcpp::Gemma& model = *env_.GetGemma();
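For context, a minimal usage sketch of the const-correct API after this change. Illustrative only: it assumes an already-initialized GemmaEnv named `env` (construction and loader arguments elided).

    // Callers no longer need a mutable copy of the prompt: QueryModel and
    // BatchQueryModel now take const references, so literals and const data
    // can be passed directly, and the per-input mutable_prompt copies in the
    // batch path go away.
    const std::string question = "What is the capital of France?";
    QueryResult result = env.QueryModel(question);

    const std::vector<std::string> questions = {"First question?",
                                                "Second question?"};
    std::vector<QueryResult> results = env.BatchQueryModel(questions);

The batch path now tokenizes each input once inside a single loop and hands the resulting PromptTokens views to the QueriesPromptTokens overload, instead of first materializing a separate std::vector<std::vector<int>> of copies.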