mirror of https://github.com/google/gemma.cpp.git
Increased prefill_tbatch_size to kMaxBatchSize. Gives a 1.5x speed-up for prefill on both Intel and AMD machines.
Shrank intermediate arrays used in matmul to reduce memory use. PiperOrigin-RevId: 899579842
This commit is contained in:
parent
a29e2fc655
commit
221d8df516
|
|
@ -134,7 +134,7 @@ struct RuntimeConfig {
|
|||
|
||||
// These defaults are overridden by InferenceArgs::CopyTo(*this):
|
||||
// Max tokens per batch during prefill.
|
||||
size_t prefill_tbatch_size = 256;
|
||||
size_t prefill_tbatch_size = kMaxBatchSize;
|
||||
// Max queries per batch (one token from each) during decode.
|
||||
size_t decode_qbatch_size = 16;
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
|
||||
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{kMaxBatchSize},
|
||||
"Prefill: max tokens per batch.");
|
||||
visitor(decode_qbatch_size, "decode_qbatch", size_t{16},
|
||||
"Decode: max queries per batch.");
|
||||
|
|
|
|||
|
|
@ -54,12 +54,12 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
|
|||
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
|
||||
|
||||
// For `MMTilesC`.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxMC = 512;
|
||||
HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
|
||||
HWY_INLINE_VAR constexpr size_t kMaxMC = 256;
|
||||
HWY_INLINE_VAR constexpr size_t kMaxNC = 6 * 1024;
|
||||
|
||||
// Upper bound for per-worker B storage on the stack. Chosen such that one row
|
||||
// of BF16 A and B fit in 32 KiB L1, but there may be up to `kMaxMR` rows of A and `kNR` rows of B.
|
||||
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
|
||||
HWY_INLINE_VAR constexpr size_t kMaxKC = 6 * 1024;
|
||||
|
||||
// Policy classes for parallelism, implementing some of `Parallelism`.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue