From b3b4b9f92faeadfed727c0e1f1fadfaf36c32c2e Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 24 Feb 2025 10:21:21 -0800 Subject: [PATCH] With new matmul, much larger batch sizes are advantageous, default to 256. Can still override via command line argument. PiperOrigin-RevId: 730502653 --- gemma/gemma.h | 2 +- util/app.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma/gemma.h b/gemma/gemma.h index d1a33a6..d7be609 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -100,7 +100,7 @@ struct RuntimeConfig { // These defaults are overridden by InferenceArgs::CopyTo(*this): // Max tokens per batch during prefill. - size_t prefill_tbatch_size = 32; + size_t prefill_tbatch_size = 256; // Max queries per batch (one token from each) during decode. size_t decode_qbatch_size = 16; diff --git a/util/app.h b/util/app.h index 49a75b5..5c0698f 100644 --- a/util/app.h +++ b/util/app.h @@ -273,7 +273,7 @@ struct InferenceArgs : public ArgsBase { visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, "Maximum number of tokens to generate."); - visitor(prefill_tbatch_size, "prefill_tbatch", size_t{64}, + visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256}, "Prefill: max tokens per batch."); visitor(decode_qbatch_size, "decode_qbatch", size_t{16}, "Decode: max queries per batch.");