mirror of https://github.com/google/gemma.cpp.git
Merge 0e0ae5a910 into cab77f8dc7
This commit is contained in:
commit
aafb2a98bd
|
|
@ -601,6 +601,10 @@ static void GenerateT(const ModelConfig& config,
|
|||
SetWeightStats(layer, activations, env.ctx);
|
||||
}
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
|
||||
}
|
||||
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||
// No-op if the profiler is disabled, but useful to separate prefill and
|
||||
|
|
@ -609,6 +613,10 @@ static void GenerateT(const ModelConfig& config,
|
|||
fprintf(stderr, "\n");
|
||||
}
|
||||
env.ctx.profiler.PrintResults();
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
|
||||
}
|
||||
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
|
|
@ -621,6 +629,10 @@ static void GenerateT(const ModelConfig& config,
|
|||
const SampleFunc sample_token =
|
||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
|
||||
}
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
|
|
@ -628,6 +640,10 @@ static void GenerateT(const ModelConfig& config,
|
|||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: generate ]\n");
|
||||
}
|
||||
}
|
||||
|
||||
// Same as GenerateT, but uses ContinuousQBatch.
|
||||
|
|
@ -733,6 +749,10 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
const ModelConfig vit_config = GetVitConfig(config);
|
||||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
||||
}
|
||||
timing_info.NotifyImageTokenStart();
|
||||
|
||||
{
|
||||
|
|
@ -755,6 +775,10 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
env.ctx.profiler.PrintResults();
|
||||
|
||||
timing_info.NotifyImageTokenDone(num_tokens);
|
||||
if (timing_info.verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
10
gemma/run.cc
10
gemma/run.cc
|
|
@ -294,6 +294,7 @@ void Run(const GemmaArgs& args) {
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::InternalInit();
|
||||
int verbosity = 0;
|
||||
{
|
||||
// Negligible CPU time.
|
||||
gcpp::ConsumedArgs consumed(argc, argv);
|
||||
|
|
@ -313,8 +314,17 @@ int main(int argc, char** argv) {
|
|||
// After `HasHelp` so that we print --help even if unconsumed args remain.
|
||||
consumed.AbortIfUnconsumed();
|
||||
|
||||
verbosity = args.inference.verbosity;
|
||||
gcpp::Run(args);
|
||||
}
|
||||
if (verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
|
||||
}
|
||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||
if (verbosity >= 2) {
|
||||
fflush(stdout);
|
||||
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue