mirror of https://github.com/google/gemma.cpp.git
Minor: batched NotifyGenerate, fix comment/dep
PiperOrigin-RevId: 799889802
This commit is contained in:
parent
86afd53076
commit
5411fd846d
|
|
@ -526,6 +526,7 @@ cc_library(
|
|||
"//io",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//hwy/contrib/sort:vqsort",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
|
|
|
|||
|
|
@ -434,11 +434,12 @@ static void SampleAndStream(
|
|||
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
|
||||
env.ctx);
|
||||
|
||||
timing_info.NotifyGenerated(non_eos.Count());
|
||||
|
||||
// TODO: parallelize
|
||||
non_eos.Foreach([&](size_t qi) {
|
||||
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
||||
// We streamed all prefill tokens, but pos is still one behind because we
|
||||
// started generation at pos = prompt.size() - 1. We want the pos argument
|
||||
|
|
|
|||
|
|
@ -177,9 +177,10 @@ struct TimingInfo {
|
|||
|
||||
// be sure to populate prefill_start and generate_start before calling
|
||||
// NotifyGenerated.
|
||||
void NotifyGenerated() {
|
||||
++tokens_generated;
|
||||
if (HWY_UNLIKELY(tokens_generated == 1)) {
|
||||
void NotifyGenerated(size_t batch_size) {
|
||||
const bool is_first = (tokens_generated == 0);
|
||||
tokens_generated += batch_size;
|
||||
if (HWY_UNLIKELY(is_first)) {
|
||||
time_to_first_token = hwy::platform::Now() - prefill_start;
|
||||
if (verbosity >= 1) {
|
||||
double prefill_tok_sec =
|
||||
|
|
@ -191,7 +192,7 @@ struct TimingInfo {
|
|||
prefill_tok_sec, static_cast<int>(time_to_first_token * 1000));
|
||||
}
|
||||
}
|
||||
if (verbosity >= 2 && tokens_generated % 128 == 0) {
|
||||
if (HWY_UNLIKELY(verbosity >= 2 && tokens_generated % 1024 == 0)) {
|
||||
double gen_tok_sec = static_cast<double>(tokens_generated) /
|
||||
(hwy::platform::Now() - generate_start);
|
||||
fprintf(stderr,
|
||||
|
|
|
|||
|
|
@ -679,7 +679,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConst(const float c, XT* HWY_RESTRICT x,
|
|||
}
|
||||
}
|
||||
|
||||
// Same as above, but without a separate output. Same as below without the add.
|
||||
// Same as above, but with a separate output. Same as below without the add.
|
||||
template <typename XT, typename OT>
|
||||
HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstTo(
|
||||
const float c, const XT* HWY_RESTRICT x, OT* HWY_RESTRICT out,
|
||||
|
|
|
|||
Loading…
Reference in New Issue