mirror of https://github.com/google/gemma.cpp.git
Merge 0e0ae5a910 into cab77f8dc7
This commit is contained in:
commit
aafb2a98bd
|
|
@ -601,6 +601,10 @@ static void GenerateT(const ModelConfig& config,
|
||||||
SetWeightStats(layer, activations, env.ctx);
|
SetWeightStats(layer, activations, env.ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
|
||||||
|
}
|
||||||
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||||
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||||
// No-op if the profiler is disabled, but useful to separate prefill and
|
// No-op if the profiler is disabled, but useful to separate prefill and
|
||||||
|
|
@ -609,6 +613,10 @@ static void GenerateT(const ModelConfig& config,
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
env.ctx.profiler.PrintResults();
|
env.ctx.profiler.PrintResults();
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
|
||||||
|
}
|
||||||
|
|
||||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||||
|
|
||||||
|
|
@ -621,6 +629,10 @@ static void GenerateT(const ModelConfig& config,
|
||||||
const SampleFunc sample_token =
|
const SampleFunc sample_token =
|
||||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||||
|
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
|
||||||
|
}
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
|
|
@ -628,6 +640,10 @@ static void GenerateT(const ModelConfig& config,
|
||||||
qbatch, env, non_eos, timing_info);
|
qbatch, env, non_eos, timing_info);
|
||||||
}
|
}
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ END PHASE: generate ]\n");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same as GenerateT, but uses ContinuousQBatch.
|
// Same as GenerateT, but uses ContinuousQBatch.
|
||||||
|
|
@ -733,6 +749,10 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
const ModelConfig vit_config = GetVitConfig(config);
|
const ModelConfig vit_config = GetVitConfig(config);
|
||||||
const size_t num_tokens = vit_config.max_seq_len;
|
const size_t num_tokens = vit_config.max_seq_len;
|
||||||
|
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
|
||||||
|
}
|
||||||
timing_info.NotifyImageTokenStart();
|
timing_info.NotifyImageTokenStart();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
@ -755,6 +775,10 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
env.ctx.profiler.PrintResults();
|
env.ctx.profiler.PrintResults();
|
||||||
|
|
||||||
timing_info.NotifyImageTokenDone(num_tokens);
|
timing_info.NotifyImageTokenDone(num_tokens);
|
||||||
|
if (timing_info.verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
|
|
||||||
10
gemma/run.cc
10
gemma/run.cc
|
|
@ -294,6 +294,7 @@ void Run(const GemmaArgs& args) {
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
gcpp::InternalInit();
|
gcpp::InternalInit();
|
||||||
|
int verbosity = 0;
|
||||||
{
|
{
|
||||||
// Negligible CPU time.
|
// Negligible CPU time.
|
||||||
gcpp::ConsumedArgs consumed(argc, argv);
|
gcpp::ConsumedArgs consumed(argc, argv);
|
||||||
|
|
@ -313,8 +314,17 @@ int main(int argc, char** argv) {
|
||||||
// After `HasHelp` so that we print --help even if unconsumed args remain.
|
// After `HasHelp` so that we print --help even if unconsumed args remain.
|
||||||
consumed.AbortIfUnconsumed();
|
consumed.AbortIfUnconsumed();
|
||||||
|
|
||||||
|
verbosity = args.inference.verbosity;
|
||||||
gcpp::Run(args);
|
gcpp::Run(args);
|
||||||
}
|
}
|
||||||
|
if (verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
|
||||||
|
}
|
||||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||||
|
if (verbosity >= 2) {
|
||||||
|
fflush(stdout);
|
||||||
|
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
|
||||||
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue