Added QueryResultAndMetrics and BatchQueryModelWithMetrics, which return TimingInfo in addition to the query results.

PiperOrigin-RevId: 810634261
Charles Zhao 2025-09-23 17:01:56 -07:00 committed by Copybara-Service
parent fac8aac4cb
commit 4f0c633248
3 changed files with 45 additions and 15 deletions
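At a glance, the new API can be used as follows. This is a minimal sketch based on the diff below; the GemmaEnv construction is elided, the prompts are illustrative, and the include path and the QueryResult::response field are assumptions rather than part of this commit.

#include <iostream>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"  // assumed include path for GemmaEnv

void RunBatch(gcpp::GemmaEnv& env) {
  std::vector<std::string> prompts = {"What is 2 + 2?", "Name three colors."};
  // New in this commit: returns the per-query results plus batch TimingInfo.
  gcpp::QueryResultAndMetrics rm = env.BatchQueryModelWithMetrics(prompts);
  for (const gcpp::QueryResult& r : rm.query_results) {
    std::cout << r.response << "\n";  // assumes the result text is in `response`
  }
  std::cout << "tokens generated: " << rm.timing_info.tokens_generated << "\n";
}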

View File

@@ -78,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
                   timing_info);
   return result;
 }
 
-void GemmaEnv::QueryModel(
-    const std::vector<int>& tokens, const StreamFunc& stream_token) {
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+void GemmaEnv::QueryModel(const std::vector<int>& tokens,
+                          const StreamFunc& stream_token) {
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   const StreamFunc previous_stream_token = runtime_config_.stream_token;
   runtime_config_.stream_token = stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
@@ -95,7 +95,7 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const QueriesPromptTokens& queries_prompt,
     const hwy::Span<const size_t>& prefix_end) {
   const size_t num_queries = queries_prompt.size();
@@ -140,7 +140,13 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
   gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
-  return res;
+  return {res, timing_info};
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const QueriesPromptTokens& queries_prompt,
+    const hwy::Span<const size_t>& prefix_end) {
+  return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results;
 }
 
 QueryResult GemmaEnv::QueryModel(const std::string& input) {
@@ -148,7 +154,7 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
   return QueryModel(prompt);
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const std::vector<std::string>& prompt_strings) {
   std::vector<PromptTokens> views;
   views.reserve(prompt_strings.size());
@@ -161,7 +167,12 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   }
 
   QueriesPromptTokens span_of_views(views.data(), views.size());
-  return BatchQueryModel(span_of_views);
+  return BatchQueryModelWithMetrics(span_of_views);
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const std::vector<std::string>& inputs) {
+  return BatchQueryModelWithMetrics(inputs).query_results;
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
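Because the old entry points now simply delegate to the metrics-returning variants, existing callers compile unchanged. A caller that also wants the timings can migrate as in this sketch; the env and prompts parameters and the header providing GemmaEnv are assumed as in the earlier sketch.

#include <cstdio>
#include <string>
#include <vector>

void Migrate(gcpp::GemmaEnv& env, const std::vector<std::string>& prompts) {
  // Before this commit: results only.
  std::vector<gcpp::QueryResult> results = env.BatchQueryModel(prompts);
  // After: the same results plus timing, unpacked with C++17 structured
  // bindings (QueryResultAndMetrics is a plain two-member aggregate).
  auto [results2, timing] = env.BatchQueryModelWithMetrics(prompts);
  std::fprintf(stderr, "generated %zu tokens\n", timing.tokens_generated);
}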

View File

@@ -39,6 +39,14 @@ struct QueryResult {
   size_t response_start_pos = 0;
 };
 
+// Return type for batch query model calls with metrics.
+struct QueryResultAndMetrics {
+  // The query results for each query in the batch.
+  std::vector<QueryResult> query_results;
+  // The timing information for the batch query.
+  TimingInfo timing_info;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -79,21 +87,30 @@ class GemmaEnv {
     return string;
   }
 
-  // Adds turn structure to input, tokenizes and calls the below overload.
-  QueryResult QueryModel(const std::string& input);
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   QueryResult QueryModel(const std::vector<int>& tokens);
-  // Runs inference on the given input and calls the callback for each token.
-  void QueryModel(const std::vector<int>& tokens,
-                  const StreamFunc& stream_token);
   // Similar to the above, but runs inference on a batch of inputs.
-  std::vector<QueryResult> BatchQueryModel(
-      const std::vector<std::string>& inputs);
   // The default prefix_end means "causal attention".
   std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
+  // Adds turn structure to input, tokenizes and calls the above overload.
+  QueryResult QueryModel(const std::string& input);
+  std::vector<QueryResult> BatchQueryModel(
+      const std::vector<std::string>& prompt_strings);
+
+  // Runs inference on the given input and calls the callback for each token.
+  void QueryModel(const std::vector<int>& tokens,
+                  const StreamFunc& stream_token);
+
+  // Similar to the above, but returns timing information in addition to the
+  // query results.
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const std::vector<std::string>& prompt_strings);
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const QueriesPromptTokens& queries_prompt,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
 
   // Runs inference on the given input and returns the cross entropy, a measure
   // of how well the model predicts the correct output. It is the average
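For the token-span overload, a call site might look like the following sketch. It assumes PromptTokens and QueriesPromptTokens are the hwy::Span views used in the .cc diff above, and that GemmaEnv exposes a Tokenize helper; neither appears in this diff, so treat both as assumptions.

void RunWithPrefix(gcpp::GemmaEnv& env) {
  // Hypothetical tokenized queries; Tokenize is assumed to exist on GemmaEnv.
  std::vector<int> q0 = env.Tokenize("What is 2 + 2?");
  std::vector<int> q1 = env.Tokenize("Name three colors.");
  const gcpp::PromptTokens views[] = {
      gcpp::PromptTokens(q0.data(), q0.size()),
      gcpp::PromptTokens(q1.data(), q1.size())};
  const gcpp::QueriesPromptTokens prompts(views, 2);
  // Omitting prefix_end keeps the default empty span, i.e. causal attention;
  // a non-empty span would mark the end of each query's non-causal prefix.
  gcpp::QueryResultAndMetrics rm = env.BatchQueryModelWithMetrics(prompts);
  (void)rm;
}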

View File

@@ -178,6 +178,7 @@ struct TimingInfo {
   // be sure to populate prefill_start and generate_start before calling
   // NotifyGenerated.
   void NotifyGenerated(size_t batch_size) {
+    generation_steps += 1;
     const bool is_first = (tokens_generated == 0);
     tokens_generated += batch_size;
     if (HWY_UNLIKELY(is_first)) {
@@ -224,6 +225,7 @@ struct TimingInfo {
   double time_to_first_token = 0;
   double generate_duration = 0;
   size_t tokens_generated = 0;
+  size_t generation_steps = 0;
 };
 
 // After construction, all methods are const and thread-compatible if using
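With the new generation_steps counter, callers can derive simple throughput metrics from TimingInfo. A sketch, assuming generate_duration is measured in seconds (the helper names are illustrative):

// Metrics derivable from the fields above.
double MeanDecodeBatchSize(const gcpp::TimingInfo& t) {
  // NotifyGenerated runs once per decode step with that step's batch size,
  // so tokens_generated / generation_steps is the mean batch size.
  if (t.generation_steps == 0) return 0.0;
  return static_cast<double>(t.tokens_generated) /
         static_cast<double>(t.generation_steps);
}

double TokensPerSecond(const gcpp::TimingInfo& t) {
  if (t.generate_duration <= 0.0) return 0.0;  // assumed to be seconds
  return static_cast<double>(t.tokens_generated) / t.generate_duration;
}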