mirror of https://github.com/google/gemma.cpp.git
Add const qualifiers to string parameters in benchmark_helper.cc and paligemma_helper.cc to remove a few unnecessary copies.
PiperOrigin-RevId: 807004597
This commit is contained in:
parent
c9b8479f7d
commit
59db30e209
|
|
@ -137,24 +137,21 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryResult GemmaEnv::QueryModel(std::string& input) {
|
QueryResult GemmaEnv::QueryModel(const std::string& input) {
|
||||||
const std::vector<int> prompt = WrapAndTokenize(input);
|
const std::vector<int> prompt = WrapAndTokenize(input);
|
||||||
return QueryModel(prompt);
|
return QueryModel(prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
const std::vector<std::string>& inputs) {
|
const std::vector<std::string>& inputs) {
|
||||||
std::vector<std::vector<int>> prompts;
|
|
||||||
prompts.reserve(inputs.size());
|
|
||||||
for (auto& input : inputs) {
|
|
||||||
std::string mutable_prompt = input;
|
|
||||||
prompts.push_back(WrapAndTokenize(mutable_prompt));
|
|
||||||
}
|
|
||||||
std::vector<PromptTokens> prompt_vector;
|
std::vector<PromptTokens> prompt_vector;
|
||||||
prompt_vector.reserve(prompts.size());
|
prompt_vector.reserve(inputs.size());
|
||||||
for (auto& prompt : prompts) {
|
|
||||||
|
for (auto& input : inputs) {
|
||||||
|
std::vector<int> prompt = WrapAndTokenize(input);
|
||||||
prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
|
prompt_vector.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
|
QueriesPromptTokens prompt_span(prompt_vector.data(), prompt_vector.size());
|
||||||
return BatchQueryModel(prompt_span);
|
return BatchQueryModel(prompt_span);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ class GemmaEnv {
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
std::vector<int> WrapAndTokenize(const std::string& input) const {
|
||||||
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
return gcpp::WrapAndTokenize(gemma_.Tokenizer(), gemma_.ChatTemplate(),
|
||||||
gemma_.Config().wrapping, 0, input);
|
gemma_.Config().wrapping, 0, input);
|
||||||
}
|
}
|
||||||
|
|
@ -87,7 +87,7 @@ class GemmaEnv {
|
||||||
const QueriesPromptTokens& queries_prompt,
|
const QueriesPromptTokens& queries_prompt,
|
||||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||||
QueryResult QueryModel(std::string& input);
|
QueryResult QueryModel(const std::string& input);
|
||||||
std::vector<QueryResult> BatchQueryModel(
|
std::vector<QueryResult> BatchQueryModel(
|
||||||
const std::vector<std::string>& inputs);
|
const std::vector<std::string>& inputs);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,8 +43,7 @@ std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string mutable_prompt = prompt_text;
|
std::vector<int> tokens = env_->WrapAndTokenize(prompt_text);
|
||||||
std::vector<int> tokens = env_->WrapAndTokenize(mutable_prompt);
|
|
||||||
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
|
||||||
|
|
||||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ class GemmaModel {
|
||||||
|
|
||||||
// Generates a single example, given a prompt and a callback to stream the
|
// Generates a single example, given a prompt and a callback to stream the
|
||||||
// generated tokens.
|
// generated tokens.
|
||||||
void GenerateEx(std::string prompt, gcpp::StreamFunc stream,
|
void GenerateEx(const std::string& prompt, gcpp::StreamFunc stream,
|
||||||
size_t max_generated_tokens, float temperature,
|
size_t max_generated_tokens, float temperature,
|
||||||
float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
|
float /*seed*/, gcpp::AcceptFunc accept, bool skip_prompt) {
|
||||||
std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
|
std::vector<int> prompt_tokens = env_.WrapAndTokenize(prompt);
|
||||||
|
|
@ -75,7 +75,7 @@ class GemmaModel {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates a single example, given a prompt, and returns the result.
|
// Generates a single example, given a prompt, and returns the result.
|
||||||
std::string Generate(std::string prompt, size_t max_generated_tokens,
|
std::string Generate(const std::string& prompt, size_t max_generated_tokens,
|
||||||
float temperature, float /*seed*/,
|
float temperature, float /*seed*/,
|
||||||
const std::vector<std::string>& accept,
|
const std::vector<std::string>& accept,
|
||||||
const std::vector<std::string>& end) {
|
const std::vector<std::string>& end) {
|
||||||
|
|
@ -192,7 +192,7 @@ class GemmaModel {
|
||||||
// Generates a response to the given prompt, using the last set image.
|
// Generates a response to the given prompt, using the last set image.
|
||||||
// Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
|
// Uses the prompt_tokens if provided, otherwise tokenizes the prompt string.
|
||||||
std::pair<std::string, std::vector<int>> GenerateWithImage(
|
std::pair<std::string, std::vector<int>> GenerateWithImage(
|
||||||
std::string prompt, size_t max_generated_tokens, float temperature,
|
const std::string& prompt, size_t max_generated_tokens, float temperature,
|
||||||
float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
|
float /*seed*/, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
|
||||||
if (!image_tokens_) throw std::invalid_argument("No image set.");
|
if (!image_tokens_) throw std::invalid_argument("No image set.");
|
||||||
const gcpp::Gemma& model = *env_.GetGemma();
|
const gcpp::Gemma& model = *env_.GetGemma();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue