Minor: batched NotifyGenerate, fix comment/dep

PiperOrigin-RevId: 799889802
This commit is contained in:
Jan Wassenberg 2025-08-26 23:32:43 -07:00 committed by Copybara-Service
parent 86afd53076
commit 5411fd846d
4 changed files with 9 additions and 6 deletions

View File

@ -526,6 +526,7 @@ cc_library(
"//io",
"//paligemma:image",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",

View File

@ -434,11 +434,12 @@ static void SampleAndStream(
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
env.ctx);
timing_info.NotifyGenerated(non_eos.Count());
// TODO: parallelize
non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi);
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
// We streamed all prefill tokens, but pos is still one behind because we
// started generation at pos = prompt.size() - 1. We want the pos argument

View File

@ -177,9 +177,10 @@ struct TimingInfo {
// be sure to populate prefill_start and generate_start before calling
// NotifyGenerated.
void NotifyGenerated() {
++tokens_generated;
if (HWY_UNLIKELY(tokens_generated == 1)) {
void NotifyGenerated(size_t batch_size) {
const bool is_first = (tokens_generated == 0);
tokens_generated += batch_size;
if (HWY_UNLIKELY(is_first)) {
time_to_first_token = hwy::platform::Now() - prefill_start;
if (verbosity >= 1) {
double prefill_tok_sec =
@ -191,7 +192,7 @@ struct TimingInfo {
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
}
}
if (verbosity >= 2 && tokens_generated % 128 == 0) {
if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
double gen_tok_sec = static_cast<double>(tokens_generated) /
(hwy::platform::Now() - generate_start);
fprintf(stderr,

View File

@ -679,7 +679,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
}
}
// Same as above, but without a separate output. Same as below without the add.
// Same as above, but with a separate output. Same as below without the add.
template <typename XT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,