Minor: batched NotifyGenerated, fix comment/dep

PiperOrigin-RevId: 799889802
This commit is contained in:
Jan Wassenberg 2025-08-26 23:32:43 -07:00 committed by Copybara-Service
parent 86afd53076
commit 5411fd846d
4 changed files with 9 additions and 6 deletions

View File

@ -526,6 +526,7 @@ cc_library(
"//io", "//io",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@highway//:nanobenchmark", # timer "@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@highway//:thread_pool", "@highway//:thread_pool",

View File

@ -434,11 +434,12 @@ static void SampleAndStream(
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos, MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
env.ctx); env.ctx);
timing_info.NotifyGenerated(non_eos.Count());
// TODO: parallelize // TODO: parallelize
non_eos.Foreach([&](size_t qi) { non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi); float* HWY_RESTRICT logits = activations.logits.Row(qi);
const TokenAndProb tp = sample_token(logits, config.vocab_size); const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
// We streamed all prefill tokens, but pos is still one behind because we // We streamed all prefill tokens, but pos is still one behind because we
// started generation at pos = prompt.size() - 1. We want the pos argument // started generation at pos = prompt.size() - 1. We want the pos argument

View File

@ -177,9 +177,10 @@ struct TimingInfo {
// be sure to populate prefill_start and generate_start before calling // be sure to populate prefill_start and generate_start before calling
// NotifyGenerated. // NotifyGenerated.
void NotifyGenerated() { void NotifyGenerated(size_t batch_size) {
++tokens_generated; const bool is_first = (tokens_generated == 0);
if (HWY_UNLIKELY(tokens_generated == 1)) { tokens_generated += batch_size;
if (HWY_UNLIKELY(is_first)) {
time_to_first_token = hwy::platform::Now() - prefill_start; time_to_first_token = hwy::platform::Now() - prefill_start;
if (verbosity >= 1) { if (verbosity >= 1) {
double prefill_tok_sec = double prefill_tok_sec =
@ -191,7 +192,7 @@ struct TimingInfo {
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000)); prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
} }
} }
if (verbosity >= 2 && tokens_generated % 128 == 0) { if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
double gen_tok_sec = static_cast<double>(tokens_generated) / double gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - generate_start); (hwy::platform::Now() - generate_start);
fprintf(stderr, fprintf(stderr,

View File

@ -679,7 +679,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
} }
} }
// Same as above, but without a separate output. Same as below without the add. // Same as above, but with a separate output. Same as below without the add.
template <typename XT, typename OT> template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out, const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,