Improve performance logging

PiperOrigin-RevId: 660534330
This commit is contained in:
The gemma.cpp Authors 2024-08-07 14:15:03 -07:00 committed by Copybara-Service
parent 4154f5a910
commit 27258b03e6
4 changed files with 43 additions and 38 deletions

View File

@ -114,9 +114,6 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
size_t query_index, size_t pos, int token, float) {
++total_tokens;
res += StringFromTokens(std::vector<int>{token});
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
};
if (app_.verbosity >= 2) {
@ -125,13 +122,10 @@ std::pair<std::string, size_t> GemmaEnv::QueryModel(
<< inference_args_.max_generated_tokens
<< "\ttemperature: " << inference_args_.temperature << "\n";
}
gcpp::TimingInfo timing_info;
gcpp::TimingInfo timing_info { .verbosity = app_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return {res, total_tokens};
}
@ -153,9 +147,6 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
res[query_index].first.append(token_text);
res[query_index].second += 1;
++total_tokens;
if (app_.verbosity >= 1 && total_tokens % 128 == 0) {
LogSpeedStats(time_start, total_tokens);
}
return true;
};
if (app_.verbosity >= 2) {
@ -177,14 +168,11 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
}
}
gcpp::TimingInfo timing_info;
gcpp::TimingInfo timing_info = {.verbosity = app_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token;
inference_args_.CopyTo(runtime_config_);
model_->GenerateBatch(runtime_config_, prompts, /*start_pos=*/0,
KVCaches(&kv_caches_[0], num_queries), timing_info);
if (app_.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
return res;
}

View File

@ -900,7 +900,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
}
Softmax(logits, kVocabSize);
const int token = sample_token(logits, kVocabSize);
timing_info.NotifyGenerated(prefill_start);
timing_info.NotifyGenerated(prefill_start, gen_start);
const bool is_eos = token_streamer(query_idx_start + query_idx,
prefill_per_query + 1 + gen_per_query,

View File

@ -32,7 +32,7 @@
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/base.h" // hwy::bfloat16_t
namespace gcpp {
@ -82,37 +82,63 @@ struct RuntimeConfig {
std::mt19937* gen;
StreamFunc stream_token;
BatchStreamFunc batch_stream_token;
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
AcceptFunc accept_token; // if empty, accepts all tokens.
SampleFunc sample_func; // if empty, uses SampleTopK.
LayersOutputFunc layers_output; // if not empty, called after each layer.
int eos_id = EOS_ID;
};
// Collects prefill / generation timing and, depending on `verbosity`, prints
// "[ Timing info ]" lines to stderr.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers, so
// pre-change and post-change lines appear back to back (two NotifyGenerated
// signatures, and members declared both with and without initializers).
// Presumably the later line of each pair is the post-change version - confirm
// against the repository before treating this text as compilable code.
struct TimingInfo {
// Records the outcome of prefill: elapsed time since `start`, the number of
// prompt tokens, and resets first-token latency and the generated-token count.
void NotifyPrefill(size_t tokens, double start) {
// NOTE(review): the prefill_tok_sec / gen_tok_sec assignments look like the
// pre-change lines of this hunk; the prefill_duration / prefill_tokens
// assignments below appear to replace them - confirm.
prefill_tok_sec =
static_cast<double>(tokens) / (hwy::platform::Now() - start);
gen_tok_sec = 0.0;
prefill_duration = hwy::platform::Now() - start;
prefill_tokens = tokens;
time_to_first_token = 0.0;
tokens_generated = 0;
}
// Counts one generated token.
// - On the first token: stores the time elapsed since `prefill_start` as
//   time_to_first_token and, at verbosity >= 1, prints prefill statistics.
// - At verbosity >= 2: every 128 tokens, prints the average generation speed
//   measured from `gen_start`.
// NOTE(review): the one-argument signature directly below is presumably the
// pre-change line; the two-argument one is what the GenerateT hunk now calls.
void NotifyGenerated(double prefill_start) {
void NotifyGenerated(double prefill_start, double gen_start) {
++tokens_generated;
if (HWY_UNLIKELY(tokens_generated == 1)) {
time_to_first_token = hwy::platform::Now() - prefill_start;
if (verbosity >= 1) {
// prefill_duration defaults to 0, so this ratio is not finite if
// NotifyPrefill has not stored a non-zero duration beforehand.
double prefill_tok_sec =
static_cast<double>(prefill_tokens) / prefill_duration;
// Durations are multiplied by 1000 and truncated to int for the "ms" fields.
fprintf(stderr,
"\n\n[ Timing info ] Prefill: %d ms for %zu prompt tokens "
"(%.2f tokens / sec); Time to first token: %d ms\n",
static_cast<int>(prefill_duration * 1000), prefill_tokens,
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
}
}
// Periodic progress report: running average over all tokens generated so far.
if (verbosity >= 2 && tokens_generated % 128 == 0) {
double gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - gen_start);
fprintf(stderr,
"\n\n[ Timing info ] %zu tokens generated "
"(avg speed %.2f tokens / sec)\n\n",
tokens_generated, gen_tok_sec);
}
}
// Records the total generation time since `gen_start` and, at verbosity >= 1,
// prints the generation summary (duration, token count, tokens / sec).
void NotifyGenerateDone(double gen_start) {
// NOTE(review): this member assignment to gen_tok_sec looks like the
// pre-change form; the generate_duration path below appears to replace it.
gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - gen_start);
generate_duration = hwy::platform::Now() - gen_start;
if (verbosity >= 1) {
double gen_tok_sec =
static_cast<double>(tokens_generated) / generate_duration;
fprintf(stderr,
"\n[ Timing info ] Generate: %d ms for %zu tokens (%.2f tokens / "
"sec)\n",
static_cast<int>(generate_duration * 1000), tokens_generated,
gen_tok_sec);
}
}
// NOTE(review): the four uninitialized members below duplicate or predate the
// initialized ones that follow; presumably they are the pre-change
// declarations - confirm.
double prefill_tok_sec;
double gen_tok_sec;
double time_to_first_token;
size_t tokens_generated;
// Logging threshold: >= 1 enables the prefill and generate summaries,
// >= 2 additionally enables the every-128-tokens progress line.
int verbosity = 0;
// Difference of two hwy::platform::Now() readings; printed as ms after * 1000.
double prefill_duration = 0;
// Token count passed to NotifyPrefill.
size_t prefill_tokens = 0;
// Time from `prefill_start` to the first NotifyGenerated call.
double time_to_first_token = 0;
// Time from `gen_start` to NotifyGenerateDone.
double generate_duration = 0;
// Number of NotifyGenerated calls since the last NotifyPrefill.
size_t tokens_generated = 0;
};
using PromptTokens = hwy::Span<const int>;

View File

@ -144,7 +144,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
}
}
TimingInfo timing_info;
TimingInfo timing_info = {.verbosity = verbosity};
RuntimeConfig runtime_config = {
.verbosity = verbosity,
.gen = &gen,
@ -153,15 +153,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
};
args.CopyTo(runtime_config);
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
if (verbosity >= 2) {
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
<< "\n"
<< timing_info.prefill_tok_sec << " prefill tokens / sec"
<< "\n"
<< timing_info.gen_tok_sec << " tokens / sec" << "\n"
<< static_cast<int>(timing_info.time_to_first_token * 1000)
<< " milliseconds time to first token" << "\n";
}
std::cout << "\n\n";
}
std::cout