Improved timing for image tokens

Move to TimingInfo, extra newline before profiler

PiperOrigin-RevId: 881943820
This commit is contained in:
Jan Wassenberg 2026-03-11 04:47:36 -07:00 committed by Copybara-Service
parent 70cb9cf1c2
commit cab77f8dc7
7 changed files with 71 additions and 46 deletions

View File

@@ -29,7 +29,6 @@
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
#ifdef _WIN32
#include <Windows.h>
@@ -195,17 +194,10 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Use the existing runtime_config defined earlier in the function.
// RuntimeConfig runtime_config = { ... }; // This was already defined
double image_tokens_start = hwy::platform::Now();
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config,
active_conversation->kv_cache->SeqLen(), image,
image_tokens, matmul_env);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
ss.str("");
ss << "\n\n[ Timing info ] Image token generation took: ";
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
LogDebug(ss.str().c_str());
image_tokens, matmul_env, timing_info);
prompt = WrapAndTokenize(
model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,

View File

@@ -605,6 +605,9 @@ static void GenerateT(const ModelConfig& config,
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
// generate phases for profiling.
if constexpr (PROFILER_ENABLED) {
fprintf(stderr, "\n");
}
env.ctx.profiler.PrintResults();
hwy::BitSet4096<> non_eos; // indexed by qi
@@ -725,25 +728,33 @@ void GenerateBatchT(const ModelConfig& config,
void GenerateImageTokensT(const ModelConfig& config,
const RuntimeConfig& runtime_config, size_t seq_len,
const WeightsPtrs& weights, const Image& image,
ImageTokens& image_tokens, MatMulEnv& env) {
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens);
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ImageTokens& image_tokens, MatMulEnv& env,
TimingInfo& timing_info) {
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(runtime_config, vit_config, num_tokens,
num_tokens, env.ctx, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
timing_info.NotifyImageTokenStart();
{
GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenImageTokens);
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(runtime_config, vit_config, num_tokens,
num_tokens, env.ctx, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
} // end GCPP_ZONE before we print results.
// No-op if the profiler is disabled. Printing now ensures that the
// `PrintResults` after prefill does not include the image token part.
env.ctx.profiler.PrintResults();
timing_info.NotifyImageTokenDone(num_tokens);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
@@ -814,13 +825,13 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
size_t seq_len, const Image& image,
ImageTokens& image_tokens,
MatMulEnv& env) const {
ImageTokens& image_tokens, MatMulEnv& env,
TimingInfo& timing_info) const {
env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
seq_len, weights_, image,
image_tokens, env);
image_tokens, env, timing_info);
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

View File

@@ -65,6 +65,21 @@ class ContinuousQBatch : public QBatch {
};
struct TimingInfo {
void NotifyImageTokenStart() { image_tokens_start = hwy::platform::Now(); }
void NotifyImageTokenDone(size_t tokens) {
image_tokens_duration = hwy::platform::Now() - image_tokens_start;
image_tokens = tokens;
if (verbosity >= 1) {
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms (%.1f "
"tok/sec)\n",
static_cast<int>(image_tokens_duration * 1E3),
image_tokens / image_tokens_duration);
}
}
// be sure to populate prefill_start before calling NotifyPrefill.
void NotifyPrefill(size_t tokens) {
prefill_duration = hwy::platform::Now() - prefill_start;
@@ -87,8 +102,8 @@ struct TimingInfo {
fprintf(stderr,
"\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
"(%.2f tokens / sec); Time to first token: %d ms\n",
static_cast<int>(prefill_duration * 1000), prefill_tokens,
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
static_cast<int>(prefill_duration * 1E3), prefill_tokens,
prefill_tok_sec, static_cast<int>(time_to_first_token * 1E3));
}
}
if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
@@ -110,20 +125,27 @@ struct TimingInfo {
fprintf(stderr,
"\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
"sec)\n",
static_cast<int>(generate_duration * 1000), tokens_generated,
static_cast<int>(generate_duration * 1E3), tokens_generated,
gen_tok_sec);
}
}
int verbosity = 0;
double prefill_start = 0;
double generate_start = 0;
double prefill_duration = 0;
double image_tokens_start = 0.0;
double image_tokens_duration = 0.0;
size_t image_tokens = 0;
double prefill_start = 0.0;
double prefill_duration = 0.0;
size_t prefill_tokens = 0;
double time_to_first_token = 0;
double generate_duration = 0;
double generate_start = 0.0;
double generate_duration = 0.0;
size_t tokens_generated = 0;
double time_to_first_token = 0.0;
size_t generation_steps = 0;
int verbosity = 0;
};
// After construction, all methods are const and thread-compatible if using
@@ -173,7 +195,7 @@ class Gemma {
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
const Image& image, ImageTokens& image_tokens,
MatMulEnv& env) const;
MatMulEnv& env, TimingInfo& timing_info) const;
private:
BlobReader reader_;

View File

@@ -99,6 +99,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
size_t prompt_size = 0;
const ModelConfig& config = gemma.Config();
TimingInfo timing_info = {.verbosity = inference.verbosity};
const bool have_image = !inference.image_file.path.empty();
Image image;
const size_t pool_dim = config.vit_config.pool_dim;
@@ -117,15 +119,8 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = verbosity,
.use_spinning = args.threading.spin};
double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens, env);
if (verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n",
static_cast<int>(image_tokens_duration * 1000));
}
image_tokens, env, timing_info);
}
// callback function invoked for each generated token.
@@ -188,7 +183,6 @@ void ReplGemma(const GemmaArgs& args, const Gemma& gemma, KVCache& kv_cache,
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.verbosity = inference.verbosity,
.batch_stream_token = batch_stream_token,
.use_spinning = args.threading.spin};

View File

@@ -29,7 +29,8 @@ void PaliGemmaHelper::InitVit(const std::string& path) {
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_->MutableKVCache().SeqLen(),
image, *image_tokens_, env_->MutableEnv());
image, *image_tokens_, env_->MutableEnv(),
timing_info_);
}
std::string PaliGemmaHelper::GemmaReply(const std::string& prompt_text) const {

View File

@@ -3,7 +3,9 @@
#include <memory>
#include <string>
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
namespace gcpp {
@@ -18,6 +20,7 @@ class PaliGemmaHelper {
private:
std::unique_ptr<ImageTokens> image_tokens_;
GemmaEnv* env_;
TimingInfo timing_info_;
};
} // namespace gcpp

View File

@@ -183,7 +183,8 @@ class GemmaModel {
env_.MutableEnv().ctx.allocator, gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, env_.MutableKVCache().SeqLen(),
c_image, *image_tokens_, env_.MutableEnv());
c_image, *image_tokens_, env_.MutableEnv(),
timing_info_);
}
// Generates a response to the given prompt, using the last set image.
@@ -244,6 +245,7 @@ class GemmaModel {
private:
gcpp::GemmaEnv env_;
std::unique_ptr<gcpp::ImageTokens> image_tokens_;
gcpp::TimingInfo timing_info_;
float last_prob_;
};