Increased max_tbatch_size to kMaxBatchSize. Gives 1.5x speed-up for prefill on both intel and AMD machines

Shrank intermediate arrays used in matmul to reduce memory use.

PiperOrigin-RevId: 899579842
This commit is contained in:
Ray Smith 2026-04-14 07:36:19 -07:00 committed by Copybara-Service
parent a29e2fc655
commit 221d8df516
2 changed files with 5 additions and 5 deletions

View File

@ -134,7 +134,7 @@ struct RuntimeConfig {
// These defaults are overridden by InferenceArgs::CopyTo(*this):
// Max tokens per batch during prefill.
size_t prefill_tbatch_size = 256;
size_t prefill_tbatch_size = kMaxBatchSize;
// Max queries per batch (one token from each) during decode.
size_t decode_qbatch_size = 16;
@ -225,7 +225,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
"Maximum number of tokens to generate.");
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{kMaxBatchSize},
"Prefill: max tokens per batch.");
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
"Decode: max queries per batch.");

View File

@ -54,12 +54,12 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
// For `MMTilesC`.
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
HWY_INLINE_VAR constexpr size_t kMaxMC = 256;
HWY_INLINE_VAR constexpr size_t kMaxNC = 6 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
HWY_INLINE_VAR constexpr size_t kMaxKC = 6 * 1024;
// Policy classes for parallelism, implementing some of `Parallelism`.