From 27258b03e67078951408db77cb2732746ea5a627 Mon Sep 17 00:00:00 2001 From: "The gemma.cpp Authors" Date: Wed, 7 Aug 2024 14:15:03 -0700 Subject: [PATCH] Improve performance logging PiperOrigin-RevId: 660534330 --- evals/benchmark_helper.cc | 16 ++---------- gemma/gemma-inl.h | 2 +- gemma/gemma.h | 52 +++++++++++++++++++++++++++++---------- gemma/run.cc | 11 +-------- 4 files changed, 43 insertions(+), 38 deletions(-) diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 04a27ed..9298337 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -114,9 +114,6 @@ std::pair GemmaEnv::QueryModel( size_t query_index, size_t pos, int token, float) { ++total_tokens; res += StringFromTokens(std::vector{token}); - if (app_.verbosity >= 1 && total_tokens % 128 == 0) { - LogSpeedStats(time_start, total_tokens); - } return true; }; if (app_.verbosity >= 2) { @@ -125,13 +122,10 @@ std::pair GemmaEnv::QueryModel( << inference_args_.max_generated_tokens << "\ttemperature: " << inference_args_.temperature << "\n"; } - gcpp::TimingInfo timing_info; + gcpp::TimingInfo timing_info { .verbosity = app_.verbosity }; runtime_config_.batch_stream_token = batch_stream_token; model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], timing_info); - if (app_.verbosity >= 1) { - LogSpeedStats(time_start, total_tokens); - } return {res, total_tokens}; } @@ -153,9 +147,6 @@ std::vector> GemmaEnv::BatchQueryModel2( res[query_index].first.append(token_text); res[query_index].second += 1; ++total_tokens; - if (app_.verbosity >= 1 && total_tokens % 128 == 0) { - LogSpeedStats(time_start, total_tokens); - } return true; }; if (app_.verbosity >= 2) { @@ -177,14 +168,11 @@ std::vector> GemmaEnv::BatchQueryModel2( } } - gcpp::TimingInfo timing_info; + gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity}; runtime_config_.batch_stream_token = batch_stream_token; inference_args_.CopyTo(runtime_config_); model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0, KVCaches(&kv_caches_[0], num_queries), timing_info); - if (app_.verbosity >= 1) { - LogSpeedStats(time_start, total_tokens); - } return res; } diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 4d96c2d..6c31356 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -900,7 +900,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, } Softmax(logits, kVocabSize); const int token = sample_token(logits, kVocabSize); - timing_info.NotifyGenerated(prefill_start); + timing_info.NotifyGenerated(prefill_start, gen_start); const bool is_eos = token_streamer(query_idx_start + query_idx, prefill_per_query + 1 + gen_per_query, diff --git a/gemma/gemma.h b/gemma/gemma.h index da4028d..1363e1b 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -32,7 +32,7 @@ #include "hwy/timer.h" // IWYU pragma: end_exports #include "hwy/aligned_allocator.h" // Span -#include "hwy/base.h" // hwy::bfloat16_t +#include "hwy/base.h" // hwy::bfloat16_t namespace gcpp { @@ -82,37 +82,63 @@ struct RuntimeConfig { std::mt19937* gen; StreamFunc stream_token; BatchStreamFunc batch_stream_token; - AcceptFunc accept_token; // if empty, accepts all tokens. - SampleFunc sample_func; // if empty, uses SampleTopK. + AcceptFunc accept_token; // if empty, accepts all tokens. + SampleFunc sample_func; // if empty, uses SampleTopK. LayersOutputFunc layers_output; // if not empty, called after each layer. int eos_id = EOS_ID; }; struct TimingInfo { void NotifyPrefill(size_t tokens, double start) { - prefill_tok_sec = - static_cast(tokens) / (hwy::platform::Now() - start); - gen_tok_sec = 0.0; + prefill_duration = hwy::platform::Now() - start; + prefill_tokens = tokens; time_to_first_token = 0.0; tokens_generated = 0; } - void NotifyGenerated(double prefill_start) { + void NotifyGenerated(double prefill_start, double gen_start) { ++tokens_generated; if (HWY_UNLIKELY(tokens_generated == 1)) { time_to_first_token = hwy::platform::Now() - prefill_start; + if (verbosity >= 1) { + double prefill_tok_sec = + static_cast(prefill_tokens) / prefill_duration; + fprintf(stderr, + "\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens " + "(%.2f tokens / sec); Time to first token: %d ms\n", + static_cast(prefill_duration * 1000), prefill_tokens, + prefill_tok_sec, static_cast(time_to_first_token * 1000)); + } + } + if (verbosity >= 2 && tokens_generated % 128 == 0) { + double gen_tok_sec = static_cast(tokens_generated) / + (hwy::platform::Now() - gen_start); + fprintf(stderr, + "\n\n[ Timing info ] %zu tokens generated " + "(avg speed %.2f tokens / sec)\n\n", + tokens_generated, gen_tok_sec); } } void NotifyGenerateDone(double gen_start) { - gen_tok_sec = static_cast(tokens_generated) / - (hwy::platform::Now() - gen_start); + generate_duration = hwy::platform::Now() - gen_start; + if (verbosity >= 1) { + double gen_tok_sec = + static_cast(tokens_generated) / generate_duration; + fprintf(stderr, + "\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / " + "sec)\n", + static_cast(generate_duration * 1000), tokens_generated, + gen_tok_sec); + } } - double prefill_tok_sec; - double gen_tok_sec; - double time_to_first_token; - size_t tokens_generated; + int verbosity = 0; + double prefill_duration = 0; + size_t prefill_tokens = 0; + double time_to_first_token = 0; + double generate_duration = 0; + size_t tokens_generated = 0; }; using PromptTokens = hwy::Span; diff --git a/gemma/run.cc b/gemma/run.cc index 33a9e2b..b519593 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, } } - TimingInfo timing_info; + TimingInfo timing_info = {.verbosity = verbosity}; RuntimeConfig runtime_config = { .verbosity = verbosity, .gen = &gen, @@ -153,15 +153,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, }; args.CopyTo(runtime_config); model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info); - if (verbosity >= 2) { - std::cout << current_pos << " tokens (" << abs_pos << " total tokens)" - << "\n" - << timing_info.prefill_tok_sec << " prefill tokens / sec" - << "\n" - << timing_info.gen_tok_sec << " tokens / sec" << "\n" - << static_cast(timing_info.time_to_first_token * 1000) - << " milliseconds time to first token" << "\n"; - } std::cout << "\n\n"; } std::cout