Patch summary (free text before the first "diff --git" header; patch tools skip it):
  * evals/gemma_batch_bench.cc: after printing the batch answers, limit generation
    to 3 tokens and call BatchGemmaReply a second time before the profiler results
    are printed.
  * gemma/activations.h: SetBatchSize no longer overrides the row count of q_T.
  * util/mat.h: OverrideRows replaces HWY_ASSERT with an explicit check that aborts
    with a message containing name_, the requested rows and private_rows_.

diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 135c2bb..3ffa858 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
     fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
   }
 
+  // Run again: prefill will be faster due to autotuning. Fewer decode steps
+  // because those are already fast.
+  s_env->SetMaxGeneratedTokens(3);
+  responses = BatchGemmaReply(inputs);
+
   PROFILER_PRINT_RESULTS();
 }
 }  // namespace
diff --git a/gemma/activations.h b/gemma/activations.h
index 9460d15..f474c84 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -100,7 +100,7 @@ struct AttentionActivations {
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
-    q_T.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
 
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
diff --git a/util/mat.h b/util/mat.h
index 4360b69..c2427e5 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -186,7 +186,10 @@ class MatPtr : public IFields {
   // will return this value. Used to set the actual number of rows for
   // activations preallocated according to the batch size.
   void OverrideRows(size_t rows) {
-    HWY_ASSERT(rows <= private_rows_);
+    if (HWY_UNLIKELY(rows > private_rows_)) {
+      HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
+                private_rows_);
+    }
     override_rows_ = static_cast<uint32_t>(rows);
   }
 