mirror of https://github.com/google/gemma.cpp.git
Add/use MaybePrint; also ShowConfig in non-interactive builds
PiperOrigin-RevId: 882688835
This commit is contained in:
parent
197c1a049c
commit
529c201eb6
|
|
@ -122,12 +122,11 @@ QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
runtime_config_.batch_stream_token = batch_stream_token;
|
runtime_config_.batch_stream_token = batch_stream_token;
|
||||||
if (runtime_config_.verbosity >= 2) {
|
MaybePrint(runtime_config_.verbosity,
|
||||||
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
"Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||||
runtime_config_.prefill_tbatch_size,
|
runtime_config_.prefill_tbatch_size,
|
||||||
runtime_config_.decode_qbatch_size);
|
runtime_config_.decode_qbatch_size);
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure we have at least one KVCache per query.
|
// Ensure we have at least one KVCache per query.
|
||||||
while (kv_caches_.size() < num_queries) {
|
while (kv_caches_.size() < num_queries) {
|
||||||
|
|
@ -223,8 +222,6 @@ static constexpr const char* CompiledConfig() {
|
||||||
return "tsan";
|
return "tsan";
|
||||||
} else if constexpr (HWY_IS_HWASAN) {
|
} else if constexpr (HWY_IS_HWASAN) {
|
||||||
return "hwasan";
|
return "hwasan";
|
||||||
} else if constexpr (HWY_IS_UBSAN) {
|
|
||||||
return "ubsan";
|
|
||||||
} else if constexpr (HWY_IS_DEBUG_BUILD) {
|
} else if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||||
return "dbg";
|
return "dbg";
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -245,6 +242,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
|
||||||
WeightsPtrs::ToString(weight_read_mode));
|
WeightsPtrs::ToString(weight_read_mode));
|
||||||
|
|
||||||
if (args.inference.verbosity >= 2) {
|
if (args.inference.verbosity >= 2) {
|
||||||
|
// (fprintf instead of MaybePrint due to local variables)
|
||||||
time_t now = time(nullptr);
|
time_t now = time(nullptr);
|
||||||
char* dt = ctime(&now); // NOLINT
|
char* dt = ctime(&now); // NOLINT
|
||||||
char cpu100[100] = "unknown";
|
char cpu100[100] = "unknown";
|
||||||
|
|
|
||||||
|
|
@ -601,10 +601,7 @@ static void GenerateT(const ModelConfig& config,
|
||||||
SetWeightStats(layer, activations, env.ctx);
|
SetWeightStats(layer, activations, env.ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (timing_info.verbosity >= 2) {
|
MaybePrint(timing_info.verbosity, "[ BEGIN PHASE: prefill ]");
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
|
|
||||||
}
|
|
||||||
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||||
// No-op if the profiler is disabled, but useful to separate prefill and
|
// No-op if the profiler is disabled, but useful to separate prefill and
|
||||||
|
|
@ -613,10 +610,6 @@ static void GenerateT(const ModelConfig& config,
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
env.ctx.profiler.PrintResults();
|
env.ctx.profiler.PrintResults();
|
||||||
if (timing_info.verbosity >= 2) {
|
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||||
|
|
||||||
|
|
@ -629,10 +622,8 @@ static void GenerateT(const ModelConfig& config,
|
||||||
const SampleFunc sample_token =
|
const SampleFunc sample_token =
|
||||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||||
|
|
||||||
if (timing_info.verbosity >= 2) {
|
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: generate ]\n");
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
|
|
||||||
}
|
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
|
|
@ -640,10 +631,6 @@ static void GenerateT(const ModelConfig& config,
|
||||||
qbatch, env, non_eos, timing_info);
|
qbatch, env, non_eos, timing_info);
|
||||||
}
|
}
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
if (timing_info.verbosity >= 2) {
|
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ END PHASE: generate ]\n");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as GenerateT, but uses ContinuousQBatch.
|
// Same as GenerateT, but uses ContinuousQBatch.
|
||||||
|
|
@ -749,10 +736,7 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
const ModelConfig vit_config = GetVitConfig(config);
|
const ModelConfig vit_config = GetVitConfig(config);
|
||||||
const size_t num_tokens = vit_config.max_seq_len;
|
const size_t num_tokens = vit_config.max_seq_len;
|
||||||
|
|
||||||
if (timing_info.verbosity >= 2) {
|
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
|
||||||
}
|
|
||||||
timing_info.NotifyImageTokenStart();
|
timing_info.NotifyImageTokenStart();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
@ -775,10 +759,6 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
env.ctx.profiler.PrintResults();
|
env.ctx.profiler.PrintResults();
|
||||||
|
|
||||||
timing_info.NotifyImageTokenDone(num_tokens);
|
timing_info.NotifyImageTokenDone(num_tokens);
|
||||||
if (timing_info.verbosity >= 2) {
|
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
|
|
||||||
49
gemma/run.cc
49
gemma/run.cc
|
|
@ -258,31 +258,32 @@ void Run(const GemmaArgs& args) {
|
||||||
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
|
||||||
|
|
||||||
if (inference.verbosity >= 1) {
|
if (inference.verbosity >= 1) {
|
||||||
std::string instructions =
|
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
|
||||||
"*Usage*\n"
|
|
||||||
" Enter an instruction and press enter (%C resets conversation, "
|
|
||||||
"%Q quits).\n";
|
|
||||||
const std::string multiturn =
|
|
||||||
inference.multiturn == 0
|
|
||||||
? std::string(
|
|
||||||
" Since multiturn is set to 0, conversation will "
|
|
||||||
"automatically reset every turn.\n\n")
|
|
||||||
: "\n";
|
|
||||||
const std::string examples =
|
|
||||||
"*Examples*\n"
|
|
||||||
" - Write an email to grandma thanking her for the cookies.\n"
|
|
||||||
" - What are some historical attractions to visit around "
|
|
||||||
"Massachusetts?\n"
|
|
||||||
" - Compute the nth fibonacci number in javascript.\n"
|
|
||||||
" - Write a standup comedy bit about GPU programming.\n";
|
|
||||||
instructions += multiturn;
|
|
||||||
instructions += examples;
|
|
||||||
|
|
||||||
// Skip the banner and instructions in non-interactive mode
|
// Skip the banner and instructions in non-interactive mode
|
||||||
if (inference.IsInteractive()) {
|
if (inference.IsInteractive()) {
|
||||||
|
std::string instructions =
|
||||||
|
"*Usage*\n"
|
||||||
|
" Enter an instruction and press enter (%C resets conversation, "
|
||||||
|
"%Q quits).\n";
|
||||||
|
const std::string multiturn =
|
||||||
|
inference.multiturn == 0
|
||||||
|
? std::string(
|
||||||
|
" Since multiturn is set to 0, conversation will "
|
||||||
|
"automatically reset every turn.\n\n")
|
||||||
|
: "\n";
|
||||||
|
const std::string examples =
|
||||||
|
"*Examples*\n"
|
||||||
|
" - Write an email to grandma thanking her for the cookies.\n"
|
||||||
|
" - What are some historical attractions to visit around "
|
||||||
|
"Massachusetts?\n"
|
||||||
|
" - Compute the nth fibonacci number in javascript.\n"
|
||||||
|
" - Write a standup comedy bit about GPU programming.\n";
|
||||||
|
instructions += multiturn;
|
||||||
|
instructions += examples;
|
||||||
|
|
||||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||||
<< kAsciiArtBanner << "\n\n";
|
<< kAsciiArtBanner << "\n\n";
|
||||||
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
|
|
||||||
std::cout << "\n" << instructions << "\n";
|
std::cout << "\n" << instructions << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -317,14 +318,6 @@ int main(int argc, char** argv) {
|
||||||
verbosity = args.inference.verbosity;
|
verbosity = args.inference.verbosity;
|
||||||
gcpp::Run(args);
|
gcpp::Run(args);
|
||||||
}
|
}
|
||||||
if (verbosity >= 2) {
|
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
|
|
||||||
}
|
|
||||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||||
if (verbosity >= 2) {
|
|
||||||
fflush(stdout);
|
|
||||||
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
|
|
||||||
}
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,10 @@
|
||||||
|
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
|
|
||||||
|
#include <stdarg.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "hwy/contrib/sort/vqsort.h"
|
#include "hwy/contrib/sort/vqsort.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
@ -24,6 +26,16 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
void MaybePrint(int verbosity, const char* format, ...) {
|
||||||
|
char buf[800];
|
||||||
|
va_list args;
|
||||||
|
va_start(args, format);
|
||||||
|
vsnprintf(buf, sizeof(buf), format, args);
|
||||||
|
va_end(args);
|
||||||
|
|
||||||
|
fprintf(stderr, "%s\n", buf); // \n ensures flush.
|
||||||
|
}
|
||||||
|
|
||||||
AesCtrEngine::AesCtrEngine(bool deterministic) {
|
AesCtrEngine::AesCtrEngine(bool deterministic) {
|
||||||
// Pi-based nothing up my sleeve numbers from Randen.
|
// Pi-based nothing up my sleeve numbers from Randen.
|
||||||
key_[0] = 0x243F6A8885A308D3ull;
|
key_[0] = 0x243F6A8885A308D3ull;
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,9 @@ static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If verbosity >= 2, prints the formatted message to stderr.
|
||||||
|
void MaybePrint(int verbosity, const char* format, ...);
|
||||||
|
|
||||||
// Shared between gemma.h and ops-inl.h.
|
// Shared between gemma.h and ops-inl.h.
|
||||||
#pragma pack(push, 1)
|
#pragma pack(push, 1)
|
||||||
struct TokenAndProb {
|
struct TokenAndProb {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue