From cab77f8dc7286a5ef5702d86abf7bdbcc38b9c0f Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 11 Mar 2026 04:47:36 -0700 Subject: [PATCH] Improved timing for image tokens Move to TimingInfo, extra newline before profiler PiperOrigin-RevId: 881943820 --- gemma/bindings/context.cc | 10 +------- gemma/gemma.cc | 43 ++++++++++++++++++++++------------- gemma/gemma.h | 42 ++++++++++++++++++++++++++-------- gemma/run.cc | 12 +++------- paligemma/paligemma_helper.cc | 3 ++- paligemma/paligemma_helper.h | 3 +++ python/gemma_py.cc | 4 +++- 7 files changed, 71 insertions(+), 46 deletions(-) diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index 5db3adc..d86d460 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -29,7 +29,6 @@ #include "util/threading.h" #include "util/threading_context.h" #include "hwy/profiler.h" -#include "hwy/timer.h" #ifdef _WIN32 #include @@ -195,17 +194,10 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // Use the existing runtime_config defined earlier in the function. // RuntimeConfig runtime_config = { ... }; // This was already defined - double image_tokens_start = hwy::platform::Now(); // Pass the populated image object to GenerateImageTokens model.GenerateImageTokens(runtime_config, active_conversation->kv_cache->SeqLen(), image, - image_tokens, matmul_env); - double image_tokens_duration = hwy::platform::Now() - image_tokens_start; - - ss.str(""); - ss << "\n\n[ Timing info ] Image token generation took: "; - ss << static_cast(image_tokens_duration * 1000) << " ms\n", - LogDebug(ss.str().c_str()); + image_tokens, matmul_env, timing_info); prompt = WrapAndTokenize( model.Tokenizer(), model.ChatTemplate(), model_config.wrapping, diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 72bad8f..4529d96 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -605,6 +605,9 @@ static void GenerateT(const ModelConfig& config, config, runtime_config, weights, activations, qbatch, env, timing_info); // No-op if the profiler is disabled, but useful to separate prefill and // generate phases for profiling. + if constexpr (PROFILER_ENABLED) { + fprintf(stderr, "\n"); + } env.ctx.profiler.PrintResults(); hwy::BitSet4096<> non_eos; // indexed by qi @@ -725,25 +728,33 @@ void GenerateBatchT(const ModelConfig& config, void GenerateImageTokensT(const ModelConfig& config, const RuntimeConfig& runtime_config, size_t seq_len, const WeightsPtrs& weights, const Image& image, - ImageTokens& image_tokens, MatMulEnv& env) { - GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens); - if (config.vit_config.layer_configs.empty()) { - HWY_ABORT("Model does not support generating image tokens."); - } - RuntimeConfig prefill_runtime_config = runtime_config; + ImageTokens& image_tokens, MatMulEnv& env, + TimingInfo& timing_info) { const ModelConfig vit_config = GetVitConfig(config); const size_t num_tokens = vit_config.max_seq_len; - prefill_runtime_config.prefill_tbatch_size = - num_tokens / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(runtime_config, vit_config, num_tokens, - num_tokens, env.ctx, env.row_ptrs); - // Weights are for the full PaliGemma model, not just the ViT part. - PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, - prefill_activations, env); + + timing_info.NotifyImageTokenStart(); + + { + GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens); + if (config.vit_config.layer_configs.empty()) { + HWY_ABORT("Model does not support generating image tokens."); + } + RuntimeConfig prefill_runtime_config = runtime_config; + prefill_runtime_config.prefill_tbatch_size = + num_tokens / (vit_config.pool_dim * vit_config.pool_dim); + Activations prefill_activations(runtime_config, vit_config, num_tokens, + num_tokens, env.ctx, env.row_ptrs); + // Weights are for the full PaliGemma model, not just the ViT part. + PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, + prefill_activations, env); + } // end GCPP_ZONE before we print results. // No-op if the profiler is disabled. Printing now ensures that the // `PrintResults` after prefill does not include the image token part. env.ctx.profiler.PrintResults(); + + timing_info.NotifyImageTokenDone(num_tokens); } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -814,13 +825,13 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len, const Image& image, - ImageTokens& image_tokens, - MatMulEnv& env) const { + ImageTokens& image_tokens, MatMulEnv& env, + TimingInfo& timing_info) const { env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config, seq_len, weights_, image, - image_tokens, env); + image_tokens, env, timing_info); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } diff --git a/gemma/gemma.h b/gemma/gemma.h index b630a8c..f04e8b7 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -65,6 +65,21 @@ class ContinuousQBatch : public QBatch { }; struct TimingInfo { + void NotifyImageTokenStart() { image_tokens_start = hwy::platform::Now(); } + + void NotifyImageTokenDone(size_t tokens) { + image_tokens_duration = hwy::platform::Now() - image_tokens_start; + image_tokens = tokens; + + if (verbosity >= 1) { + fprintf(stderr, + "\n\n[ Timing info ] Image token generation took: %d ms (%.1f " + "tok/sec)\n", + static_cast(image_tokens_duration * 1E3), + image_tokens / image_tokens_duration); + } + } + // be sure to populate prefill_start before calling NotifyPrefill. void NotifyPrefill(size_t tokens) { prefill_duration = hwy::platform::Now() - prefill_start; @@ -87,8 +102,8 @@ struct TimingInfo { fprintf(stderr, "\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens " "(%.2f tokens / sec); Time to first token: %d ms\n", - static_cast(prefill_duration * 1000), prefill_tokens, - prefill_tok_sec, static_cast(time_to_first_token * 1000)); + static_cast(prefill_duration * 1E3), prefill_tokens, + prefill_tok_sec, static_cast(time_to_first_token * 1E3)); } } if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) { @@ -110,20 +125,27 @@ struct TimingInfo { fprintf(stderr, "\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / " "sec)\n", - static_cast(generate_duration * 1000), tokens_generated, + static_cast(generate_duration * 1E3), tokens_generated, gen_tok_sec); } } - int verbosity = 0; - double prefill_start = 0; - double generate_start = 0; - double prefill_duration = 0; + double image_tokens_start = 0.0; + double image_tokens_duration = 0.0; + size_t image_tokens = 0; + + double prefill_start = 0.0; + double prefill_duration = 0.0; size_t prefill_tokens = 0; - double time_to_first_token = 0; - double generate_duration = 0; + + double generate_start = 0.0; + double generate_duration = 0.0; size_t tokens_generated = 0; + + double time_to_first_token = 0.0; size_t generation_steps = 0; + + int verbosity = 0; }; // After construction, all methods are const and thread-compatible if using @@ -173,7 +195,7 @@ class Gemma { // Generates the image tokens by running the image encoder ViT. void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len, const Image& image, ImageTokens& image_tokens, - MatMulEnv& env) const; + MatMulEnv& env, TimingInfo& timing_info) const; private: BlobReader reader_; diff --git a/gemma/run.cc b/gemma/run.cc index 95dec0d..4c308c3 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -99,6 +99,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache, size_t prompt_size = 0; const ModelConfig& config = gemma.Config(); + TimingInfo timing_info = {.verbosity = inference.verbosity}; + const bool have_image = !inference.image_file.path.empty(); Image image; const size_t pool_dim = config.vit_config.pool_dim; @@ -117,15 +119,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache, image.Resize(image_size, image_size); RuntimeConfig runtime_config = {.verbosity = verbosity, .use_spinning = args.threading.spin}; - double image_tokens_start = hwy::platform::Now(); gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image, - image_tokens, env); - if (verbosity >= 1) { - double image_tokens_duration = hwy::platform::Now() - image_tokens_start; - fprintf(stderr, - "\n\n[ Timing info ] Image token generation took: %d ms\n", - static_cast(image_tokens_duration * 1000)); - } + image_tokens, env, timing_info); } // callback function invoked for each generated token. @@ -188,7 +183,6 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache, } // Set up runtime config. - TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.verbosity = inference.verbosity, .batch_stream_token = batch_stream_token, .use_spinning = args.threading.spin}; diff --git a/paligemma/paligemma_helper.cc b/paligemma/paligemma_helper.cc index c32e925..ed74cd8 100644 --- a/paligemma/paligemma_helper.cc +++ b/paligemma/paligemma_helper.cc @@ -29,7 +29,8 @@ void PaliGemmaHelper::InitVit(const std::string& path) { image.Resize(image_size, image_size); RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(), - image, *image_tokens_, env_->MutableEnv()); + image, *image_tokens_, env_->MutableEnv(), + timing_info_); } std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const { diff --git a/paligemma/paligemma_helper.h b/paligemma/paligemma_helper.h index 4994c43..b6bf4da 100644 --- a/paligemma/paligemma_helper.h +++ b/paligemma/paligemma_helper.h @@ -3,7 +3,9 @@ #include #include + #include "evals/benchmark_helper.h" +#include "gemma/gemma.h" #include "gemma/gemma_args.h" namespace gcpp { @@ -18,6 +20,7 @@ class PaliGemmaHelper { private: std::unique_ptr image_tokens_; GemmaEnv* env_; + TimingInfo timing_info_; }; } // namespace gcpp diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 0d056d9..ec045b0 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -183,7 +183,8 @@ class GemmaModel { env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd)); gcpp::RuntimeConfig runtime_config = {.verbosity = 0}; gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(), - c_image, *image_tokens_, env_.MutableEnv()); + c_image, *image_tokens_, env_.MutableEnv(), + timing_info_); } // Generates a response to the given prompt, using the last set image. @@ -244,6 +245,7 @@ class GemmaModel { private: gcpp::GemmaEnv env_; std::unique_ptr image_tokens_; + gcpp::TimingInfo timing_info_; float last_prob_; };