diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index abdef50..a495dea 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -78,16 +78,16 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
               << runtime_config_.max_generated_tokens
               << "\ttemperature: " << runtime_config_.temperature << "\n";
   }
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   runtime_config_.batch_stream_token = batch_stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
                   env_, timing_info);
   return result;
 }
 
-void GemmaEnv::QueryModel(
-    const std::vector<int>& tokens, const StreamFunc& stream_token) {
-  gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
+void GemmaEnv::QueryModel(const std::vector<int>& tokens,
+                          const StreamFunc& stream_token) {
+  gcpp::TimingInfo timing_info{.verbosity = runtime_config_.verbosity};
   const StreamFunc previous_stream_token = runtime_config_.stream_token;
   runtime_config_.stream_token = stream_token;
   gemma_.Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], env_,
@@ -95,7 +95,7 @@ void GemmaEnv::QueryModel(
   runtime_config_.stream_token = previous_stream_token;
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const QueriesPromptTokens& queries_prompt,
     const hwy::Span<const size_t>& prefix_end) {
   const size_t num_queries = queries_prompt.size();
@@ -140,7 +140,13 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
   gemma_.GenerateBatch(runtime_config_, all_queries, env_, timing_info);
-  return res;
+  return {res, timing_info};
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const QueriesPromptTokens& queries_prompt,
+    const hwy::Span<const size_t>& prefix_end) {
+  return BatchQueryModelWithMetrics(queries_prompt, prefix_end).query_results;
 }
 
 QueryResult GemmaEnv::QueryModel(const std::string& input) {
@@ -148,7 +154,7 @@ QueryResult GemmaEnv::QueryModel(const std::string& input) {
   return QueryModel(prompt);
 }
 
-std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
     const std::vector<std::string>& prompt_strings) {
   std::vector<PromptTokens> views;
   views.reserve(prompt_strings.size());
@@ -161,7 +167,12 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   }
 
   QueriesPromptTokens span_of_views(views.data(), views.size());
-  return BatchQueryModel(span_of_views);
+  return BatchQueryModelWithMetrics(span_of_views);
+}
+
+std::vector<QueryResult> GemmaEnv::BatchQueryModel(
+    const std::vector<std::string>& inputs) {
+  return BatchQueryModelWithMetrics(inputs).query_results;
 }
 
 float GemmaEnv::CrossEntropy(const std::string& input) {
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 75cf0d2..2380dbf 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -39,6 +39,14 @@ struct QueryResult {
   size_t response_start_pos = 0;
 };
 
+// Return type for batch query model calls with metrics.
+struct QueryResultAndMetrics {
+  // The query results for each query in the batch.
+  std::vector<QueryResult> query_results;
+  // The timing information for the batch query.
+  TimingInfo timing_info;
+};
+
 // Convenience class to load a model and run inference.
 class GemmaEnv {
  public:
@@ -79,21 +87,30 @@ class GemmaEnv {
     return string;
   }
 
+  // Adds turn structure to input, tokenizes and calls the below overload.
+  QueryResult QueryModel(const std::string& input);
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   QueryResult QueryModel(const std::vector<int>& tokens);
+  // Runs inference on the given input and calls the callback for each token.
+  void QueryModel(const std::vector<int>& tokens,
+                  const StreamFunc& stream_token);
+
+  // Similar to the above, but runs inference on a batch of inputs.
+  std::vector<QueryResult> BatchQueryModel(
+      const std::vector<std::string>& inputs);
 
   // The default prefix_end means "causal attention".
   std::vector<QueryResult> BatchQueryModel(
       const QueriesPromptTokens& queries_prompt,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
-  // Adds turn structure to input, tokenizes and calls the above overload.
-  QueryResult QueryModel(const std::string& input);
-  std::vector<QueryResult> BatchQueryModel(
-      const std::vector<std::string>& prompt_strings);
-  // Runs inference on the given input and calls the callback for each token.
-  void QueryModel(const std::vector<int>& tokens,
-                  const StreamFunc& stream_token);
+  // Similar to the above, but returns timing information in addition to the
+  // query results.
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const std::vector<std::string>& prompt_strings);
+  QueryResultAndMetrics BatchQueryModelWithMetrics(
+      const QueriesPromptTokens& queries_prompt,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
 
   // Runs inference on the given input and returns the cross entropy, a measure
   // of how well the model predicts the correct output. It is the average
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 491999d..5e40bda 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -178,6 +178,7 @@ struct TimingInfo {
   // be sure to populate prefill_start and generate_start before calling
   // NotifyGenerated.
   void NotifyGenerated(size_t batch_size) {
+    generation_steps += 1;
     const bool is_first = (tokens_generated == 0);
     tokens_generated += batch_size;
     if (HWY_UNLIKELY(is_first)) {
@@ -224,6 +225,7 @@ struct TimingInfo {
   double time_to_first_token = 0;
   double generate_duration = 0;
   size_t tokens_generated = 0;
+  size_t generation_steps = 0;
 };
 
 // After construction, all methods are const and thread-compatible if using
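
Note for reviewers: a minimal usage sketch of the new metrics path. The helper function and its name are illustrative assumptions, not part of this diff; the `gcpp` namespace for `GemmaEnv` and the `response` field of `QueryResult` are assumed from the surrounding repository. Only `BatchQueryModelWithMetrics`, `QueryResultAndMetrics`, and the `TimingInfo` fields shown above are introduced or touched by this change.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"

// Hypothetical helper: runs a batch of prompts through the new metrics path
// and prints per-query output plus aggregate throughput derived from the
// returned TimingInfo.
void RunBatchAndReport(gcpp::GemmaEnv& env,
                       const std::vector<std::string>& prompts) {
  gcpp::QueryResultAndMetrics rm = env.BatchQueryModelWithMetrics(prompts);

  for (const gcpp::QueryResult& qr : rm.query_results) {
    std::cout << qr.response << "\n";  // `response` assumed, as noted above.
  }

  const gcpp::TimingInfo& t = rm.timing_info;
  if (t.generate_duration > 0) {
    std::cout << "TTFT: " << t.time_to_first_token << " s"
              << "\tthroughput: " << t.tokens_generated / t.generate_duration
              << " tokens/s"
              // generation_steps counts decode steps, so this ratio is the
              // average number of tokens produced per step.
              << "\tavg tokens/step: "
              << static_cast<double>(t.tokens_generated) /
                     static_cast<double>(
                         std::max<size_t>(t.generation_steps, 1))
              << "\n";
  }
}

Since NotifyGenerated now increments generation_steps once per decode step while tokens_generated grows by that step's batch size, the ratio tokens_generated / generation_steps gives the average effective batch size during generation, which is useful alongside raw tokens/s when comparing batching strategies.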