From ba6131311a444d23d6aad2c8ac3acf9e5b406ca4 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 10 Sep 2025 05:32:03 -0700 Subject: [PATCH] Fix gemma_batch_bench for flash attention q_T rows do not change. Also repeat prefill to reflect perf after autotuning. PiperOrigin-RevId: 805319377 --- evals/gemma_batch_bench.cc | 5 +++++ gemma/activations.h | 2 +- util/mat.h | 5 ++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 135c2bb..3ffa858 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) { fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } + // Run again: prefill will be faster due to autotuning. Fewer decode steps + // because those are already fast. + s_env->SetMaxGeneratedTokens(3); + responses = BatchGemmaReply(inputs); + PROFILER_PRINT_RESULTS(); } } // namespace diff --git a/gemma/activations.h b/gemma/activations.h index 9460d15..f474c84 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -100,7 +100,7 @@ struct AttentionActivations { void SetBatchSize(size_t batch_size) { q.OverrideRows(batch_size); - q_T.OverrideRows(batch_size); + // q_T rows are always qkv_dim! pre_att_rms_out.OverrideRows(batch_size); att.OverrideRows(batch_size); diff --git a/util/mat.h b/util/mat.h index 4360b69..c2427e5 100644 --- a/util/mat.h +++ b/util/mat.h @@ -186,7 +186,10 @@ class MatPtr : public IFields { // will return this value. Used to set the actual number of rows for // activations preallocated according to the batch size. void OverrideRows(size_t rows) { - HWY_ASSERT(rows <= private_rows_); + if (HWY_UNLIKELY(rows > private_rows_)) { + HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows, + private_rows_); + } override_rows_ = static_cast(rows); }