diff --git a/BUILD.bazel b/BUILD.bazel index 6429f6d..1f9e210 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -526,6 +526,7 @@ cc_library( "//io", "//paligemma:image", "@highway//:hwy", + "@highway//hwy/contrib/sort:vqsort", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", diff --git a/gemma/gemma.cc b/gemma/gemma.cc index bef3a70..c6da86c 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -434,11 +434,12 @@ static void SampleAndStream( MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos, env.ctx); + timing_info.NotifyGenerated(non_eos.Count()); + // TODO: parallelize non_eos.Foreach([&](size_t qi) { float* HWY_RESTRICT logits = activations.logits.Row(qi); const TokenAndProb tp = sample_token(logits, config.vocab_size); - timing_info.NotifyGenerated(); // We streamed all prefill tokens, but pos is still one behind because we // started generation at pos = prompt.size() - 1. We want the pos argument diff --git a/gemma/gemma.h b/gemma/gemma.h index 5ebd70d..55be003 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -177,9 +177,10 @@ struct TimingInfo { // be sure to populate prefill_start and generate_start before calling // NotifyGenerated. - void NotifyGenerated() { - ++tokens_generated; - if (HWY_UNLIKELY(tokens_generated == 1)) { + void NotifyGenerated(size_t batch_size) { + const bool is_first = (tokens_generated == 0); + tokens_generated += batch_size; + if (HWY_UNLIKELY(is_first)) { time_to_first_token = hwy::platform::Now() - prefill_start; if (verbosity >= 1) { double prefill_tok_sec = @@ -191,7 +192,7 @@ struct TimingInfo { prefill_tok_sec, static_cast(time_to_first_token * 1000)); } } - if (verbosity >= 2 && tokens_generated % 128 == 0) { + if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) { double gen_tok_sec = static_cast(tokens_generated) / (hwy::platform::Now() - generate_start); fprintf(stderr, diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 95228f4..8a7224b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -679,7 +679,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x, } } -// Same as above, but without a separate output. Same as below without the add. +// Same as above, but with a separate output. Same as below without the add. template HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo( const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,