mirror of https://github.com/google/gemma.cpp.git
Fix gemma_batch_bench for flash attention
q_T rows do not change. Also repeat prefill to reflect perf after autotuning. PiperOrigin-RevId: 805319377
This commit is contained in:
parent
9457258330
commit
ba6131311a
|
|
@ -98,6 +98,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||||
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run again: prefill will be faster due to autotuning. Fewer decode steps
|
||||||
|
// because those are already fast.
|
||||||
|
s_env->SetMaxGeneratedTokens(3);
|
||||||
|
responses = BatchGemmaReply(inputs);
|
||||||
|
|
||||||
PROFILER_PRINT_RESULTS();
|
PROFILER_PRINT_RESULTS();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,7 @@ struct AttentionActivations {
|
||||||
|
|
||||||
void SetBatchSize(size_t batch_size) {
|
void SetBatchSize(size_t batch_size) {
|
||||||
q.OverrideRows(batch_size);
|
q.OverrideRows(batch_size);
|
||||||
q_T.OverrideRows(batch_size);
|
// q_T rows are always qkv_dim!
|
||||||
|
|
||||||
pre_att_rms_out.OverrideRows(batch_size);
|
pre_att_rms_out.OverrideRows(batch_size);
|
||||||
att.OverrideRows(batch_size);
|
att.OverrideRows(batch_size);
|
||||||
|
|
|
||||||
|
|
@ -186,7 +186,10 @@ class MatPtr : public IFields {
|
||||||
// will return this value. Used to set the actual number of rows for
|
// will return this value. Used to set the actual number of rows for
|
||||||
// activations preallocated according to the batch size.
|
// activations preallocated according to the batch size.
|
||||||
void OverrideRows(size_t rows) {
|
void OverrideRows(size_t rows) {
|
||||||
HWY_ASSERT(rows <= private_rows_);
|
if (HWY_UNLIKELY(rows > private_rows_)) {
|
||||||
|
HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
|
||||||
|
private_rows_);
|
||||||
|
}
|
||||||
override_rows_ = static_cast<uint32_t>(rows);
|
override_rows_ = static_cast<uint32_t>(rows);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue