Add/use MaybePrint; also ShowConfig in non-interactive builds

PiperOrigin-RevId: 882688835
This commit is contained in:
Jan Wassenberg 2026-03-12 11:20:13 -07:00 committed by Copybara-Service
parent 197c1a049c
commit 529c201eb6
5 changed files with 46 additions and 60 deletions

View File

@ -122,12 +122,11 @@ QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
return true;
};
runtime_config_.batch_stream_token = batch_stream_token;
if (runtime_config_.verbosity >= 2) {
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
}
MaybePrint(runtime_config_.verbosity,
"Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
// Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) {
@ -223,8 +222,6 @@ static constexpr const char* CompiledConfig() {
return "tsan";
} else if constexpr (HWY_IS_HWASAN) {
return "hwasan";
} else if constexpr (HWY_IS_UBSAN) {
return "ubsan";
} else if constexpr (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
@ -245,6 +242,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
WeightsPtrs::ToString(weight_read_mode));
if (args.inference.verbosity >= 2) {
// (fprintf instead of MaybePrint due to local variables)
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";

View File

@ -601,10 +601,7 @@ static void GenerateT(const ModelConfig& config,
SetWeightStats(layer, activations, env.ctx);
}
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
}
MaybePrint(timing_info.verbosity, "[ BEGIN PHASE: prefill ]");
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
@ -613,10 +610,6 @@ static void GenerateT(const ModelConfig& config,
fprintf(stderr, "\n");
}
env.ctx.profiler.PrintResults();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
}
hwy::BitSet4096<> non_eos; // indexed by qi
@ -629,10 +622,8 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
}
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: generate ]\n");
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
@ -640,10 +631,6 @@ static void GenerateT(const ModelConfig& config,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: generate ]\n");
}
}
// Same as GenerateT, but uses ContinuousQBatch.
@ -749,10 +736,7 @@ void GenerateImageTokensT(const ModelConfig& config,
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
}
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: image_token_gen ]\n");
timing_info.NotifyImageTokenStart();
{
@ -775,10 +759,6 @@ void GenerateImageTokensT(const ModelConfig& config,
env.ctx.profiler.PrintResults();
timing_info.NotifyImageTokenDone(num_tokens);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -258,31 +258,32 @@ void Run(const GemmaArgs& args) {
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
if (inference.verbosity >= 1) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
"%Q quits).\n";
const std::string multiturn =
inference.multiturn == 0
? std::string(
" Since multiturn is set to 0, conversation will "
"automatically reset every turn.\n\n")
: "\n";
const std::string examples =
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
"Massachusetts?\n"
" - Compute the nth fibonacci number in javascript.\n"
" - Write a standup comedy bit about GPU programming.\n";
instructions += multiturn;
instructions += examples;
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
// Skip the banner and instructions in non-interactive mode
if (inference.IsInteractive()) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
"%Q quits).\n";
const std::string multiturn =
inference.multiturn == 0
? std::string(
" Since multiturn is set to 0, conversation will "
"automatically reset every turn.\n\n")
: "\n";
const std::string examples =
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
"Massachusetts?\n"
" - Compute the nth fibonacci number in javascript.\n"
" - Write a standup comedy bit about GPU programming.\n";
instructions += multiturn;
instructions += examples;
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n";
}
}
@ -317,14 +318,6 @@ int main(int argc, char** argv) {
verbosity = args.inference.verbosity;
gcpp::Run(args);
}
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
}
return 0;
}

View File

@ -15,8 +15,10 @@
#include "util/basics.h"
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/highway.h"
@ -24,6 +26,16 @@
namespace gcpp {
void MaybePrint(int verbosity, const char* format, ...) {
char buf[800];
va_list args;
va_start(args, format);
vsnprintf(buf, sizeof(buf), format, args);
va_end(args);
fprintf(stderr, "%s\n", buf); // \n ensures flush.
}
AesCtrEngine::AesCtrEngine(bool deterministic) {
// Pi-based nothing up my sleeve numbers from Randen.
key_[0] = 0x243F6A8885A308D3ull;

View File

@ -81,6 +81,9 @@ static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
#endif
}
// If verbosity >= 2, prints the formatted message to stderr.
void MaybePrint(int verbosity, const char* format, ...);
// Shared between gemma.h and ops-inl.h.
#pragma pack(push, 1)
struct TokenAndProb {