mirror of https://github.com/google/gemma.cpp.git
Add/use MaybePrint; also ShowConfig in non-interactive builds
PiperOrigin-RevId: 882688835
This commit is contained in:
parent
197c1a049c
commit
529c201eb6
|
|
@ -122,12 +122,11 @@ QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
|
|||
return true;
|
||||
};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
if (runtime_config_.verbosity >= 2) {
|
||||
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||
runtime_config_.prefill_tbatch_size,
|
||||
runtime_config_.decode_qbatch_size);
|
||||
}
|
||||
MaybePrint(runtime_config_.verbosity,
|
||||
"Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||
runtime_config_.prefill_tbatch_size,
|
||||
runtime_config_.decode_qbatch_size);
|
||||
|
||||
// Ensure we have at least one KVCache per query.
|
||||
while (kv_caches_.size() < num_queries) {
|
||||
|
|
@ -223,8 +222,6 @@ static constexpr const char* CompiledConfig() {
|
|||
return "tsan";
|
||||
} else if constexpr (HWY_IS_HWASAN) {
|
||||
return "hwasan";
|
||||
} else if constexpr (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
} else if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
|
|
@ -245,6 +242,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
|
|||
WeightsPtrs::ToString(weight_read_mode));
|
||||
|
||||
if (args.inference.verbosity >= 2) {
|
||||
// (fprintf instead of MaybePrint due to local variables)
|
||||
time_t now = time(nullptr);
|
||||
char* dt = ctime(&now); // NOLINT
|
||||
char cpu100[100] = "unknown";
|
||||
|
|
|
|||
|
|
@ -601,10 +601,7 @@ static void GenerateT(const ModelConfig& config,
|
|||
SetWeightStats(layer, activations, env.ctx);
|
||||
}
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
|
||||
}
|
||||
MaybePrint(timing_info.verbosity, "[ BEGIN PHASE: prefill ]");
|
||||
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||
// No-op if the profiler is disabled, but useful to separate prefill and
|
||||
|
|
@ -613,10 +610,6 @@ static void GenerateT(const ModelConfig& config,
|
|||
fprintf(stderr, "\n");
|
||||
}
|
||||
env.ctx.profiler.PrintResults();
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
|
||||
}
|
||||
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
|
|
@ -629,10 +622,8 @@ static void GenerateT(const ModelConfig& config,
|
|||
const SampleFunc sample_token =
|
||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
|
||||
}
|
||||
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: generate ]\n");
|
||||
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
|
|
@ -640,10 +631,6 @@ static void GenerateT(const ModelConfig& config,
|
|||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: generate ]\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Same as GenerateT, but uses ContinuousQBatch.
|
||||
|
|
@ -749,10 +736,7 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
const ModelConfig vit_config = GetVitConfig(config);
|
||||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
||||
}
|
||||
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
||||
timing_info.NotifyImageTokenStart();
|
||||
|
||||
{
|
||||
|
|
@ -775,10 +759,6 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
env.ctx.profiler.PrintResults();
|
||||
|
||||
timing_info.NotifyImageTokenDone(num_tokens);
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
49
gemma/run.cc
49
gemma/run.cc
|
|
@ -258,31 +258,32 @@ void Run(const GemmaArgs& args) {
|
|||
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
"*Usage*\n"
|
||||
" Enter an instruction and press enter (%C resets conversation, "
|
||||
"%Q quits).\n";
|
||||
const std::string multiturn =
|
||||
inference.multiturn == 0
|
||||
? std::string(
|
||||
" Since multiturn is set to 0, conversation will "
|
||||
"automatically reset every turn.\n\n")
|
||||
: "\n";
|
||||
const std::string examples =
|
||||
"*Examples*\n"
|
||||
" - Write an email to grandma thanking her for the cookies.\n"
|
||||
" - What are some historical attractions to visit around "
|
||||
"Massachusetts?\n"
|
||||
" - Compute the nth fibonacci number in javascript.\n"
|
||||
" - Write a standup comedy bit about GPU programming.\n";
|
||||
instructions += multiturn;
|
||||
instructions += examples;
|
||||
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
|
||||
|
||||
// Skip the banner and instructions in non-interactive mode
|
||||
if (inference.IsInteractive()) {
|
||||
std::string instructions =
|
||||
"*Usage*\n"
|
||||
" Enter an instruction and press enter (%C resets conversation, "
|
||||
"%Q quits).\n";
|
||||
const std::string multiturn =
|
||||
inference.multiturn == 0
|
||||
? std::string(
|
||||
" Since multiturn is set to 0, conversation will "
|
||||
"automatically reset every turn.\n\n")
|
||||
: "\n";
|
||||
const std::string examples =
|
||||
"*Examples*\n"
|
||||
" - Write an email to grandma thanking her for the cookies.\n"
|
||||
" - What are some historical attractions to visit around "
|
||||
"Massachusetts?\n"
|
||||
" - Compute the nth fibonacci number in javascript.\n"
|
||||
" - Write a standup comedy bit about GPU programming.\n";
|
||||
instructions += multiturn;
|
||||
instructions += examples;
|
||||
|
||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||
<< kAsciiArtBanner << "\n\n";
|
||||
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
|
||||
std::cout << "\n" << instructions << "\n";
|
||||
}
|
||||
}
|
||||
|
|
@ -317,14 +318,6 @@ int main(int argc, char** argv) {
|
|||
verbosity = args.inference.verbosity;
|
||||
gcpp::Run(args);
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
if (verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,10 @@
|
|||
|
||||
#include "util/basics.h"
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "hwy/contrib/sort/vqsort.h"
|
||||
#include "hwy/highway.h"
|
||||
|
|
@ -24,6 +26,16 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Prints a printf-style formatted message, followed by a newline, to stderr,
// but only if `verbosity` >= 2. For lower verbosity this is a no-op.
//
// verbosity: the caller's verbosity level; messages are emitted at 2 or above.
// format, ...: printf-style format string and its arguments. The formatted
//   message is truncated to fit the 800-byte local buffer.
void MaybePrint(int verbosity, const char* format, ...) {
  // Without this check every message would be printed regardless of the
  // requested verbosity.
  if (verbosity < 2) return;

  char buf[800];
  va_list args;
  va_start(args, format);
  const int written = vsnprintf(buf, sizeof(buf), format, args);
  va_end(args);
  // vsnprintf returns a negative value on an encoding error, in which case
  // the contents of `buf` must not be printed.
  if (written < 0) return;

  // Flush pending stdout output first so that the stderr message appears
  // after text that was already written to stdout.
  fflush(stdout);
  // Emit message and newline with a single call.
  fprintf(stderr, "%s\n", buf);
}
|
||||
|
||||
AesCtrEngine::AesCtrEngine(bool deterministic) {
|
||||
// Pi-based nothing up my sleeve numbers from Randen.
|
||||
key_[0] = 0x243F6A8885A308D3ull;
|
||||
|
|
|
|||
|
|
@ -81,6 +81,9 @@ static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
|
|||
#endif
|
||||
}
|
||||
|
||||
// If verbosity >= 2, prints the formatted message to stderr. `format` is a
// printf-style format string consumed together with the variadic arguments.
// A trailing newline is appended to the message, so callers need not add one.
// The formatted message is limited by a fixed 800-byte buffer in the
// implementation; longer output is truncated.
|
||||
void MaybePrint(int verbosity, const char* format, ...);
|
||||
|
||||
// Shared between gemma.h and ops-inl.h.
|
||||
#pragma pack(push, 1)
|
||||
struct TokenAndProb {
|
||||
|
|
|
|||
Loading…
Reference in New Issue