Fix gemma_batch_bench for flash attention

q_T rows do not change with the batch size (they are always qkv_dim), so stop overriding them in SetBatchSize.
Also repeat the prefill so the benchmark reflects performance after autotuning.

PiperOrigin-RevId: 805319377
Author: Jan Wassenberg
Date:   2025-09-10 05:32:03 -07:00 (committed by Copybara-Service)
Commit: ba6131311a
Parent: 9457258330

3 changed files with 10 additions and 2 deletions


@@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
     fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
   }
+  // Run again: prefill will be faster due to autotuning. Fewer decode steps
+  // because those are already fast.
+  s_env->SetMaxGeneratedTokens(3);
+  responses = BatchGemmaReply(inputs);
+  PROFILER_PRINT_RESULTS();
 }
 }  // namespace
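For context, the second BatchGemmaReply call runs after the autotuner has settled, so PROFILER_PRINT_RESULTS reports steady-state numbers rather than first-run tuning overhead. Below is a minimal standalone sketch of this warm-up-then-measure pattern; RunBatch is a hypothetical stand-in for the workload, not the gemma.cpp API.

#include <chrono>
#include <cstdio>
#include <vector>

// Hypothetical workload standing in for BatchGemmaReply: any function whose
// first call triggers one-time work (autotuning, caching) fits this pattern.
static double RunBatch(std::vector<double>& data) {
  double sum = 0.0;
  for (double& d : data) {
    d = d * 1.000001 + 0.5;
    sum += d;
  }
  return sum;
}

int main() {
  std::vector<double> data(1 << 20, 1.0);
  RunBatch(data);  // Warm-up run: one-time tuning happens here, not measured.

  const auto start = std::chrono::steady_clock::now();
  const double sum = RunBatch(data);  // Measured run: steady-state perf.
  const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                      std::chrono::steady_clock::now() - start)
                      .count();
  std::printf("steady-state: %lld ns (checksum %f)\n",
              static_cast<long long>(ns), sum);
  return 0;
}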


@@ -100,7 +100,7 @@ struct AttentionActivations {
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
-    q_T.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
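The dropped override reflects that q_T holds q transposed: its row count is qkv_dim, a model constant, and only the other dimension tracks the batch. Here is a toy sketch of that layout with illustrative shapes and names, not the real gemma.cpp types; that the column count is what follows the batch size is an assumption.

#include <cstddef>

// Toy shapes: q is (batch, qkv_dim); its transpose q_T is (qkv_dim, batch).
struct ToyMat {
  std::size_t rows, cols;
};

int main() {
  const std::size_t qkv_dim = 256;  // assumed per-query dimension
  std::size_t batch = 4;
  ToyMat q{batch, qkv_dim};
  ToyMat q_T{qkv_dim, batch};

  batch = 8;         // new batch size
  q.rows = batch;    // analogous to q.OverrideRows(batch_size)
  q_T.cols = batch;  // assumption: only the columns track the batch
  // q_T.rows stays qkv_dim, which is why SetBatchSize must not override it.
  return 0;
}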

@@ -186,7 +186,10 @@ class MatPtr : public IFields {
   // will return this value. Used to set the actual number of rows for
   // activations preallocated according to the batch size.
   void OverrideRows(size_t rows) {
-    HWY_ASSERT(rows <= private_rows_);
+    if (HWY_UNLIKELY(rows > private_rows_)) {
+      HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
+                private_rows_);
+    }
     override_rows_ = static_cast<uint32_t>(rows);
   }
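The motivation for this change: an assertion reports only the failing expression, while the explicit branch with HWY_ABORT names the offending tensor along with both values, making capacity bugs much easier to diagnose. A minimal sketch of the same check-and-abort pattern in portable C++; the names are illustrative, not the Highway macros or the real MatPtr.

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>

struct ToyMatPtr {
  std::string name;
  unsigned private_rows = 16;  // capacity fixed at allocation time
  unsigned override_rows = 0;  // current logical row count

  void OverrideRows(std::size_t rows) {
    if (rows > private_rows) {  // always checked, even in release builds
      std::fprintf(stderr, "%s: rows %zu > private_rows %u\n", name.c_str(),
                   rows, private_rows);
      std::abort();
    }
    override_rows = static_cast<unsigned>(rows);
  }
};

int main() {
  ToyMatPtr q{"q"};
  q.OverrideRows(8);   // OK: within the preallocated capacity
  q.OverrideRows(32);  // aborts with a message naming the tensor and sizes
  return 0;
}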