diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 4f8c6e4..76c2b1e 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1227,8 +1227,7 @@ bool DecodeStepT(const ModelWeightsPtrs& weights, const size_t query_idx_start, const KVCaches& kv_caches, const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len, const size_t vocab_size, - const SampleFunc& sample_token, double prefill_start, - double gen_start, Activations& activations, + const SampleFunc& sample_token, Activations& activations, TokenStreamer& token_streamer, std::vector& gen_tokens, TimingInfo& timing_info, const QueriesMutablePos& queries_mutable_pos) { @@ -1255,7 +1254,7 @@ bool DecodeStepT(const ModelWeightsPtrs& weights, float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); const TokenAndProb tp = sample_token(logits, vocab_size); - timing_info.NotifyGenerated(prefill_start, gen_start); + timing_info.NotifyGenerated(); const bool is_eos = token_streamer(query_idx_start + query_idx, @@ -1318,7 +1317,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, // Prefill stops before min_prompt_size - 1 because the last prompt // token is the first input token for generation. - const double prefill_start = hwy::platform::Now(); + timing_info.prefill_start = hwy::platform::Now(); // If tbatch is larger than the qbatch we already have in `activations`, then // allocate prefill_activations, otherwise reuse. const bool use_prefill_activations = @@ -1337,7 +1336,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, for (size_t qi = 0; qi < num_queries; ++qi) { prefilled_tokens += queries_prompt[qi].size() - 1; } - timing_info.NotifyPrefill(prefilled_tokens, prefill_start); + timing_info.NotifyPrefill(prefilled_tokens); // queries_pos are incremented by Prefill. // Storage for the last generated token from each query, passed to the next @@ -1357,16 +1356,16 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, { const size_t vocab_size = model.Config().vocab_size; - const double gen_start = hwy::platform::Now(); + timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_generated_tokens; ++gen) { bool all_queries_eos = DecodeStepT( weights, runtime_config, queries_prompt, query_idx_start, kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, - prefill_start, gen_start, activations, token_streamer, gen_tokens, + activations, token_streamer, gen_tokens, timing_info, queries_mutable_pos); if (all_queries_eos) break; } // foreach token to generate - timing_info.NotifyGenerateDone(gen_start); + timing_info.NotifyGenerateDone(); } } diff --git a/gemma/gemma.h b/gemma/gemma.h index d0b0427..fef4b7a 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -137,14 +137,17 @@ struct RuntimeConfig { }; struct TimingInfo { - void NotifyPrefill(size_t tokens, double start) { - prefill_duration = hwy::platform::Now() - start; + // be sure to populate prefill_start before calling NotifyPrefill. + void NotifyPrefill(size_t tokens) { + prefill_duration = hwy::platform::Now() - prefill_start; prefill_tokens = tokens; time_to_first_token = 0.0; tokens_generated = 0; } - void NotifyGenerated(double prefill_start, double gen_start) { + // be sure to populate prefill_start and generate_start before calling + // NotifyGenerated. + void NotifyGenerated() { ++tokens_generated; if (HWY_UNLIKELY(tokens_generated == 1)) { time_to_first_token = hwy::platform::Now() - prefill_start; @@ -160,7 +163,7 @@ struct TimingInfo { } if (verbosity >= 2 && tokens_generated % 128 == 0) { double gen_tok_sec = static_cast(tokens_generated) / - (hwy::platform::Now() - gen_start); + (hwy::platform::Now() - generate_start); fprintf(stderr, "\n\n[ Timing info ] %zu tokens generated " "(avg speed %.2f tokens / sec)\n\n", @@ -168,8 +171,9 @@ struct TimingInfo { } } - void NotifyGenerateDone(double gen_start) { - generate_duration = hwy::platform::Now() - gen_start; + // be sure to populate generate_start before calling NotifyGenerateDone. + void NotifyGenerateDone() { + generate_duration = hwy::platform::Now() - generate_start; if (verbosity >= 1) { double gen_tok_sec = static_cast(tokens_generated) / generate_duration; @@ -182,6 +186,8 @@ struct TimingInfo { } int verbosity = 0; + double prefill_start = 0; + double generate_start = 0; double prefill_duration = 0; size_t prefill_tokens = 0; double time_to_first_token = 0;