mirror of https://github.com/google/gemma.cpp.git
Fix gemma_batch_bench for flash attention
q_T rows do not change. Also repeat prefill to reflect perf after autotuning.

PiperOrigin-RevId: 805319377
This commit is contained in:
parent 9457258330
commit ba6131311a
@@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
     fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
   }
 
+  // Run again: prefill will be faster due to autotuning. Fewer decode steps
+  // because those are already fast.
+  s_env->SetMaxGeneratedTokens(3);
+  responses = BatchGemmaReply(inputs);
+
   PROFILER_PRINT_RESULTS();
 }
 } // namespace
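For readers unfamiliar with the pattern, here is a minimal, self-contained sketch (plain C++, not gemma.cpp code) of why the hunk above runs the batch a second time: the first pass pays the autotuning/warm-up cost, the second pass reflects steady-state performance. The helper and lambda names are illustrative; the real benchmark calls BatchGemmaReply.

#include <chrono>
#include <cstdio>

// Times a single invocation of the given workload in seconds.
template <typename F>
static double SecondsFor(F&& workload) {
  const auto t0 = std::chrono::steady_clock::now();
  workload();
  const auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double>(t1 - t0).count();
}

int main() {
  auto workload = [] { /* stand-in for BatchGemmaReply(inputs) */ };
  const double cold = SecondsFor(workload);  // includes one-time tuning
  const double warm = SecondsFor(workload);  // closer to what users see
  std::printf("cold %.3f s, warm %.3f s\n", cold, warm);
  return 0;
}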
@@ -100,7 +100,7 @@ struct AttentionActivations {
 
   void SetBatchSize(size_t batch_size) {
     q.OverrideRows(batch_size);
-    q_T.OverrideRows(batch_size);
+    // q_T rows are always qkv_dim!
 
     pre_att_rms_out.OverrideRows(batch_size);
     att.OverrideRows(batch_size);
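A hedged sketch of the reasoning behind this hunk: only activations whose row count tracks the batch size should be resized, and the diff's own comment notes that q_T's row count is always qkv_dim, so overriding it with batch_size was the bug. The toy struct below is illustrative only, not the real AttentionActivations.

#include <cstddef>

// Toy illustration: batch-shaped tensors resize, the transposed query does not.
struct ToyActivations {
  size_t q_rows;    // grows/shrinks with the batch
  size_t q_T_rows;  // fixed; per the diff comment, always qkv_dim
  size_t att_rows;  // grows/shrinks with the batch

  void SetBatchSize(size_t batch_size) {
    q_rows = batch_size;
    // q_T_rows deliberately untouched: its row count never changes.
    att_rows = batch_size;
  }
};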
@@ -186,7 +186,10 @@ class MatPtr : public IFields {
   // will return this value. Used to set the actual number of rows for
   // activations preallocated according to the batch size.
   void OverrideRows(size_t rows) {
-    HWY_ASSERT(rows <= private_rows_);
+    if (HWY_UNLIKELY(rows > private_rows_)) {
+      HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
+                private_rows_);
+    }
     override_rows_ = static_cast<uint32_t>(rows);
   }
 
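A minimal sketch of the new OverrideRows() contract, using plain C++ in place of the HWY_UNLIKELY/HWY_ABORT macros. The member names are copied from the diff; the surrounding class is a stand-in, not the real MatPtr.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

// Stand-in for MatPtr, showing only the row-override bookkeeping.
struct ToyMatPtr {
  std::string name_;
  uint32_t private_rows_;   // capacity chosen at allocation time
  uint32_t override_rows_;  // rows in use for the current batch

  void OverrideRows(size_t rows) {
    if (rows > private_rows_) {  // exceeding the preallocation is fatal
      std::fprintf(stderr, "%s: rows %zu > private_rows_ %u\n", name_.c_str(),
                   rows, private_rows_);
      std::abort();
    }
    override_rows_ = static_cast<uint32_t>(rows);
  }
};

int main() {
  ToyMatPtr q{"q", /*private_rows_=*/64, /*override_rows_=*/64};
  q.OverrideRows(16);  // OK: shrink to the current batch size
  q.OverrideRows(65);  // aborts with a diagnostic, mirroring HWY_ABORT above
  return 0;
}

Relative to the previous HWY_ASSERT, the explicit branch lets the failure message report the tensor name and both row counts.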