This commit is contained in:
copybara-service[bot] 2026-03-11 19:28:44 +00:00 committed by GitHub
commit aafb2a98bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 0 deletions

View File

@ -601,6 +601,10 @@ static void GenerateT(const ModelConfig& config,
SetWeightStats(layer, activations, env.ctx);
}
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
}
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
@ -609,6 +613,10 @@ static void GenerateT(const ModelConfig& config,
fprintf(stderr, "\n");
}
env.ctx.profiler.PrintResults();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
}
hwy::BitSet4096<> non_eos; // indexed by qi
@ -621,6 +629,10 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
}
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
@ -628,6 +640,10 @@ static void GenerateT(const ModelConfig& config,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: generate ]\n");
}
}
// Same as GenerateT, but uses ContinuousQBatch.
@ -733,6 +749,10 @@ void GenerateImageTokensT(const ModelConfig& config,
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
}
timing_info.NotifyImageTokenStart();
{
@ -755,6 +775,10 @@ void GenerateImageTokensT(const ModelConfig& config,
env.ctx.profiler.PrintResults();
timing_info.NotifyImageTokenDone(num_tokens);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -294,6 +294,7 @@ void Run(const GemmaArgs& args) {
int main(int argc, char** argv) {
gcpp::InternalInit();
int verbosity = 0;
{
// Negligible CPU time.
gcpp::ConsumedArgs consumed(argc, argv);
@ -313,8 +314,17 @@ int main(int argc, char** argv) {
// After `HasHelp` so that we print --help even if unconsumed args remain.
consumed.AbortIfUnconsumed();
verbosity = args.inference.verbosity;
gcpp::Run(args);
}
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
}
return 0;
}