mirror of https://github.com/google/gemma.cpp.git
Using TimingInfo methods and cleaning up args to DecodeStepT
PiperOrigin-RevId: 725580125
commit 64cf6dfe0a
parent 953c877658
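In short: the prefill/generation start timestamps move into TimingInfo itself, so NotifyPrefill, NotifyGenerated, and NotifyGenerateDone no longer take timestamps as arguments, and DecodeStepT no longer needs the prefill_start / gen_start parameters. Below is a minimal, self-contained sketch of the resulting usage pattern; TimingInfoSketch and NowSeconds are illustrative stand-ins (using std::chrono in place of hwy::platform::Now()) and are not code from this repository.

#include <chrono>
#include <cstddef>
#include <cstdio>

// Stand-in for hwy::platform::Now(): seconds as a double.
static double NowSeconds() {
  return std::chrono::duration<double>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

// Simplified stand-in for TimingInfo after this change: callers populate the
// *_start members, then call the Notify* methods without passing timestamps.
struct TimingInfoSketch {
  void NotifyPrefill(size_t tokens) {
    prefill_duration = NowSeconds() - prefill_start;
    prefill_tokens = tokens;
    time_to_first_token = 0.0;
    tokens_generated = 0;
  }
  void NotifyGenerated() {
    ++tokens_generated;
    if (tokens_generated == 1) {
      time_to_first_token = NowSeconds() - prefill_start;
    }
  }
  void NotifyGenerateDone() {
    generate_duration = NowSeconds() - generate_start;
  }

  double prefill_start = 0;
  double generate_start = 0;
  double prefill_duration = 0;
  size_t prefill_tokens = 0;
  double time_to_first_token = 0;
  size_t tokens_generated = 0;
  double generate_duration = 0;
};

int main() {
  TimingInfoSketch timing_info;
  timing_info.prefill_start = NowSeconds();
  // ... prefill would run here ...
  timing_info.NotifyPrefill(/*tokens=*/128);
  timing_info.generate_start = NowSeconds();
  for (int gen = 0; gen < 4; ++gen) {
    // ... a decode step would run here; it only needs timing_info now ...
    timing_info.NotifyGenerated();
  }
  timing_info.NotifyGenerateDone();
  std::printf("prefill: %g s, first token: %g s, generate: %g s\n",
              timing_info.prefill_duration, timing_info.time_to_first_token,
              timing_info.generate_duration);
  return 0;
}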
@@ -1227,8 +1227,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
                  const size_t query_idx_start, const KVCaches& kv_caches,
                  const QueriesPos& queries_prefix_end,
                  const hwy::Divisor div_seq_len, const size_t vocab_size,
-                 const SampleFunc& sample_token, double prefill_start,
-                 double gen_start, Activations& activations,
+                 const SampleFunc& sample_token, Activations& activations,
                  TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
                  TimingInfo& timing_info,
                  const QueriesMutablePos& queries_mutable_pos) {
@@ -1255,7 +1254,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
     float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
     MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
     const TokenAndProb tp = sample_token(logits, vocab_size);
-    timing_info.NotifyGenerated(prefill_start, gen_start);
+    timing_info.NotifyGenerated();

     const bool is_eos =
         token_streamer(query_idx_start + query_idx,
@@ -1318,7 +1317,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,

   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
-  const double prefill_start = hwy::platform::Now();
+  timing_info.prefill_start = hwy::platform::Now();
   // If tbatch is larger than the qbatch we already have in `activations`, then
   // allocate prefill_activations, otherwise reuse.
   const bool use_prefill_activations =
@@ -1337,7 +1336,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   for (size_t qi = 0; qi < num_queries; ++qi) {
     prefilled_tokens += queries_prompt[qi].size() - 1;
   }
-  timing_info.NotifyPrefill(prefilled_tokens, prefill_start);
+  timing_info.NotifyPrefill(prefilled_tokens);
   // queries_pos are incremented by Prefill.

   // Storage for the last generated token from each query, passed to the next
@@ -1357,16 +1356,16 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,

   {
     const size_t vocab_size = model.Config().vocab_size;
-    const double gen_start = hwy::platform::Now();
+    timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
       bool all_queries_eos = DecodeStepT<T>(
           weights, runtime_config, queries_prompt, query_idx_start, kv_caches,
           queries_prefix_end, div_seq_len, vocab_size, sample_token,
-          prefill_start, gen_start, activations, token_streamer, gen_tokens,
+          activations, token_streamer, gen_tokens,
           timing_info, queries_mutable_pos);
       if (all_queries_eos) break;
     }  // foreach token to generate
-    timing_info.NotifyGenerateDone(gen_start);
+    timing_info.NotifyGenerateDone();
   }
 }

@@ -137,14 +137,17 @@ struct RuntimeConfig {
 };

 struct TimingInfo {
-  void NotifyPrefill(size_t tokens, double start) {
-    prefill_duration = hwy::platform::Now() - start;
+  // be sure to populate prefill_start before calling NotifyPrefill.
+  void NotifyPrefill(size_t tokens) {
+    prefill_duration = hwy::platform::Now() - prefill_start;
     prefill_tokens = tokens;
     time_to_first_token = 0.0;
     tokens_generated = 0;
   }

-  void NotifyGenerated(double prefill_start, double gen_start) {
+  // be sure to populate prefill_start and generate_start before calling
+  // NotifyGenerated.
+  void NotifyGenerated() {
     ++tokens_generated;
     if (HWY_UNLIKELY(tokens_generated == 1)) {
       time_to_first_token = hwy::platform::Now() - prefill_start;
@@ -160,7 +163,7 @@ struct TimingInfo {
     }
     if (verbosity >= 2 && tokens_generated % 128 == 0) {
       double gen_tok_sec = static_cast<double>(tokens_generated) /
-                           (hwy::platform::Now() - gen_start);
+                           (hwy::platform::Now() - generate_start);
       fprintf(stderr,
               "\n\n[ Timing info ] %zu tokens generated "
               "(avg speed %.2f tokens / sec)\n\n",
@@ -168,8 +171,9 @@ struct TimingInfo {
     }
   }

-  void NotifyGenerateDone(double gen_start) {
-    generate_duration = hwy::platform::Now() - gen_start;
+  // be sure to populate generate_start before calling NotifyGenerateDone.
+  void NotifyGenerateDone() {
+    generate_duration = hwy::platform::Now() - generate_start;
     if (verbosity >= 1) {
       double gen_tok_sec =
           static_cast<double>(tokens_generated) / generate_duration;
@@ -182,6 +186,8 @@ struct TimingInfo {
   }

   int verbosity = 0;
+  double prefill_start = 0;
+  double generate_start = 0;
   double prefill_duration = 0;
   size_t prefill_tokens = 0;
   double time_to_first_token = 0;
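Taken together, the GenerateT hunks establish this ordering: stamp timing_info.prefill_start before prefill, call NotifyPrefill(prefilled_tokens) after it, stamp timing_info.generate_start before the decode loop, let each decode step call NotifyGenerated(), and finish with NotifyGenerateDone(). A condensed sketch of that sequence, with the surrounding control flow elided:

  timing_info.prefill_start = hwy::platform::Now();
  // ... Prefill runs, advancing queries_pos ...
  timing_info.NotifyPrefill(prefilled_tokens);
  timing_info.generate_start = hwy::platform::Now();
  // ... decode loop; DecodeStepT calls timing_info.NotifyGenerated() per token ...
  timing_info.NotifyGenerateDone();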