From 0e5f4cbf1b2dd8dd3a165a4ef0e77afccce96e4e Mon Sep 17 00:00:00 2001
From: Charles Zhao
Date: Sun, 23 Nov 2025 23:53:28 -0800
Subject: [PATCH] Implement Continuous Batching.

(1) A function GenerateTWithContinuousBatching is added to use continuous
    batching when enabled.
(2) ContinuousQBatch is added as a subclass of QBatch to manage prefill,
    insertion, and collection of used KV caches.
(3) The unit tests are expanded to cover more diverse cases.

PiperOrigin-RevId: 836090261
---
 gemma/gemma.cc     | 135 ++++++++++++++++++++++++++++++++++++++++++---
 gemma/gemma.h      |  55 +++++++++++++++---
 gemma/gemma_args.h |   3 +
 gemma/kv_cache.h   |  15 +++--
 4 files changed, 187 insertions(+), 21 deletions(-)

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 7ed7f50..5dd665d 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -18,6 +18,8 @@
 
 #include "gemma/gemma.h"
 
+#include <optional>
+
 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
 #include "util/zones.h"
 #ifndef HWY_DISABLED_TARGETS
@@ -35,7 +37,7 @@
 // After highway.h
 #include "gemma/attention.h"  // includes highway.h
 #include "gemma/gemma-inl.h"
-#include "gemma/vit.h"  // includes highway.h
+#include "gemma/vit.h"        // includes highway.h
 
 #ifndef GEMMA_CC_ONCE
 #define GEMMA_CC_ONCE
@@ -357,6 +359,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
       (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
                                        token, 0.0f);
       qbatch.MutablePos(qi) = pos_in_prompt;
+    } else {
+      // This prevents the kv cache of eos_id from being written to the last
+      // prefilled token.
+      qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
     }
     qbatch.PrevToken(qi) = token;
@@ -589,6 +595,57 @@ static void GenerateT(const ModelConfig& config,
   timing_info.NotifyGenerateDone();
 }
 
+// Same as GenerateT, but uses ContinuousQBatch.
+static void GenerateTWithContinuousBatching(
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const AesCtrEngine& engine, const WeightsPtrs& weights,
+    Activations& activations, AllQueries& all_queries, MatMulEnv& env,
+    TimingInfo& timing_info) {
+  const size_t qbatch_size = runtime_config.decode_qbatch_size;
+
+  QBatch qbatch(0, qbatch_size, all_queries);
+  ContinuousQBatch prefill_batch(qbatch_size, all_queries);
+
+  hwy::BitSet4096<> non_eos;
+  const SampleFunc sample_token =
+      ChooseSampleFunc(runtime_config, engine, env.ctx);
+
+  int query_inserted = 0;
+  while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      // Continue if slot qi is still processing a query.
+      if (non_eos.Get(qi)) continue;
+      // Collect the kv_cache from slot qi of the qbatch into the
+      // available_kv_caches_ of the prefill_batch.
+      prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
+
+      // Prefill if there are no prefilled queries available to insert.
+      if (prefill_batch.ShouldPrefill()) {
+        prefill_batch.SetupNextBatchForPrefill();
+        PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
+                              prefill_batch, env, timing_info);
+        activations.SetBatchSize(qbatch.Size());
+      }
+
+      // Get the next query to insert into the generate batch.
+      std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
+      if (qi_to_insert) {
+        qbatch.Insert(qi_to_insert.value(), qi);
+        query_inserted++;
+
+        non_eos.Set(qi);
+        StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
+                                       qi);
+      }
+    }
+
+    Transformer(config, runtime_config, weights, activations, qbatch, env);
+    SampleAndStream(config, runtime_config, weights, sample_token, activations,
+                    qbatch, env, non_eos, timing_info);
+  }
+  timing_info.NotifyGenerateDone();
+}
+
 void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
@@ -619,12 +676,17 @@ void GenerateBatchT(const ModelConfig& config,
                           all_queries[0].kv_cache.SeqLen(), env.ctx,
                           env.row_ptrs);
 
-  for (size_t start = 0; start < all_queries.NumQueries();
-       start += runtime_config.decode_qbatch_size) {
-    QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
-    // Generate a batch of one token for each of `qbatch.Size()` queries.
-    GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
-              timing_info);
+  if (runtime_config.use_continuous_batching) {
+    GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
+                                    activations, all_queries, env, timing_info);
+  } else {
+    for (size_t start = 0; start < all_queries.NumQueries();
+         start += runtime_config.decode_qbatch_size) {
+      QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
+      // Generate a batch of one token for each of `qbatch.Size()` queries.
+      GenerateT(config, runtime_config, engine, weights, activations, qbatch,
+                env, timing_info);
+    }
   }
 }
 
@@ -721,5 +783,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
   env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
+ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
+    : QBatch(0, max_size, queries) {
+  for (size_t i = start_; i < queries_.NumQueries(); ++i) {
+    if (!queries_[i].kv_cache.IsEmpty()) {
+      // Move the kv_cache into available_kv_caches_; leaving the kv_cache in
+      // queries_ would be confusing. This simplifies the logic of kv_cache
+      // management.
+      available_kv_caches_.push_back(queries_[i].kv_cache);
+      queries_[i].kv_cache = KVCachePtr();
+    }
+  }
+}
+
+bool ContinuousQBatch::ShouldPrefill() const {
+  const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
+  const bool more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
+  return no_available_to_insert && more_queries_to_prefill;
+}
+
+void ContinuousQBatch::SetupNextBatchForPrefill() {
+  start_ = next_to_prefill_;
+  size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
+  HWY_DASSERT(size_ != 0);
+  HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
+  query_idx_.clear();
+  query_idx_.reserve(size_);
+  for (size_t i = 0; i < size_; ++i) {
+    const size_t next_query_idx = start_ + i;
+    query_idx_.push_back(next_query_idx);
+    HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
+    queries_[next_query_idx].kv_cache = available_kv_caches_.back();
+    available_kv_caches_.pop_back();
+  }
+  next_to_prefill_ += size_;
+}
+
+std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
+  if (next_to_insert_ == next_to_prefill_) {
+    return std::nullopt;
+  }
+  next_to_insert_++;
+  return next_to_insert_ - 1;
+}
+
+void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
+  const int query_to_collect = from.QueryIdx(0);
+  // Only collect if the query to collect is not also the next query to be
+  // inserted; they are the same at the beginning of each Generate call.
+  if (query_to_collect != next_to_insert_) {
+    // Only clear the KV cache if there are more queries to insert; otherwise
+    // we would crash because Transformer still accesses that KV cache.
+    if (next_to_insert_ < queries_.NumQueries()) {
+      available_kv_caches_.push_back(from.KV(0));
+      ZeroInit(from.KV(0).kv_cache);
+      from.KV(0) = KVCachePtr();
+    }
+  }
+}
+
 }  // namespace gcpp
 #endif  // HWY_ONCE
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 59875fb..158c1ff 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -18,6 +18,7 @@
 
 #include <stddef.h>
 
+#include <optional>
 #include <vector>
 
 // IWYU pragma: begin_exports
@@ -89,17 +90,17 @@ struct AllQueries {
       const hwy::Span<const PromptTokens>& prompts,
       const hwy::Span<KVCachePtr>& kv_caches,
       const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
-    HWY_ASSERT(prompts.size() == kv_caches.size());
     HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
-    per_query_.reserve(kv_caches.size());
-    for (size_t i = 0; i < kv_caches.size(); ++i) {
-      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
+    per_query_.reserve(prompts.size());
+    for (size_t i = 0; i < prompts.size(); ++i) {
+      HWY_ASSERT(kv_caches.size() == 0 ||
+                 kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
       per_query_.push_back(PerQuery{
           .prompt = prompts[i],
           .mutable_pos = 0,
           .initial_pos = 0,
           .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
-          .kv_cache = kv_caches[i],
+          .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
       });
     }
   }
@@ -142,10 +143,13 @@ class QBatch {
     HWY_ASSERT(max_size_ <= kMaxBatchSize);
     HWY_DASSERT(size_ != 0);
     HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
+    for (size_t i = 0; i < size_; ++i) {
+      query_idx_.push_back(start_ + i);
+    }
   }
 
   // Returns a single-query view starting at `qi` relative to this batch.
-  QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
+  QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
 
   // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
   size_t Size() const { return size_; }
@@ -153,7 +157,7 @@ class QBatch {
   // Returns index for use with `AllQueries` and `BatchStreamToken`.
   size_t QueryIdx(size_t qi) const {
     HWY_DASSERT(qi < size_);
-    return start_ + qi;
+    return query_idx_[qi];
   }
 
   // Accessor functions to bridge the previous SoA and current AoS layout.
@@ -171,13 +175,48 @@ class QBatch {
   KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
   int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
 
-  private:
+  // Lets query_idx_[to] point to query `from` in queries_; this is only used
+  // when the QBatch has fewer slots than there are queries.
+  void Insert(size_t from, size_t to) {
+    if (from == to) return;
+    HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
+    HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
+    // Conceptually, insert query `from` at slot `to`.
+    query_idx_[to] = from;
+  }
+
+ protected:
   size_t start_;
   size_t max_size_;
   AllQueries& queries_;
+  std::vector<size_t> query_idx_;
   size_t size_;
 };
 
+// Used for continuous batching.
+class ContinuousQBatch : public QBatch {
+ public:
+  ContinuousQBatch(size_t max_size, AllQueries& queries);
+
+  // Whether we should prefill the next batch, i.e. next_to_insert_ ==
+  // next_to_prefill_.
+  bool ShouldPrefill() const;
+
+  // Sets up query_idx_ to point to the next group of queries to prefill.
+  void SetupNextBatchForPrefill();
+
+  // Gets the next query to insert into the generate batch.
+  std::optional<size_t> GetNextToInsert();
+
+  // Collects the kv_cache from the given QBatch into available_kv_caches_.
+  void MaybeReleaseKV(const QBatch& from);
+
+ public:
+  int next_to_prefill_ = 0;
+  int next_to_insert_ = 0;
+  std::vector<KVCachePtr> available_kv_caches_;
+};
+
 struct TimingInfo {
   // be sure to populate prefill_start before calling NotifyPrefill.
   void NotifyPrefill(size_t tokens) {
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 8536a78..78e2208 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -163,6 +163,9 @@ struct RuntimeConfig {
   // default decision is likely sufficient because it is based on whether
   // threads are successfully pinned.
   mutable Tristate use_spinning = Tristate::kDefault;
+
+  // Whether to use continuous batching.
+  bool use_continuous_batching = false;
 };
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index 37a4d0e..3697116 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -28,6 +28,13 @@ namespace gcpp {
 
 using KV_t = float;
 
+// A non-owning view of a KVCache.
+struct KVCachePtr {
+  bool IsEmpty() const { return kv_cache.Rows() == 0; }
+  size_t SeqLen() const { return kv_cache.Rows(); }
+  MatPtrT<KV_t> kv_cache;
+};
+
 struct KVCache {
   KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
           const Allocator& allocator);
@@ -40,6 +47,8 @@ struct KVCache {
 
   MatStorageT<KV_t> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
 
+  KVCachePtr ToPtr() { return KVCachePtr{.kv_cache = kv_cache}; }
+
  private:
   const Allocator& allocator_;
 
@@ -47,12 +56,6 @@ struct KVCache {
   KVCache(const Extents2D& kv_extents, const Allocator& allocator);
 };
 
-// A non-owning view of a KVCache.
-struct KVCachePtr {
-  size_t SeqLen() const { return kv_cache.Rows(); }
-  MatPtrT<KV_t> kv_cache;
-};
-
 // Convenience function to create views into KVCaches.
 std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
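
For readers who want the scheduling idea without the surrounding plumbing, the
toy program below simulates it in isolation. It is not gemma.cpp code: the slot
count, the per-query decode lengths, and the EOS condition are invented for
illustration, and prefill, KV-cache reuse, and sampling are reduced to
counters. It only shows the policy that GenerateTWithContinuousBatching
implements with ContinuousQBatch: keep a fixed number of decode slots busy by
inserting the next prefilled query whenever a slot finishes and releases its
KV cache.

#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

int main() {
  constexpr size_t kSlots = 2;  // stand-in for decode_qbatch_size
  // Hypothetical per-query decode lengths; in the real code a query ends at EOS.
  const std::vector<int> steps = {3, 1, 4, 2, 2};

  std::vector<std::optional<size_t>> slot(kSlots);  // query held by each slot
  std::vector<int> remaining(steps.size(), 0);
  size_t next_to_insert = 0;  // counterpart of ContinuousQBatch::next_to_insert_
  size_t finished = 0;

  for (int step = 0; finished < steps.size(); ++step) {
    // Fill any free slot with the next (already prefilled) query, as the
    // generate loop does via GetNextToInsert() + QBatch::Insert().
    for (size_t s = 0; s < kSlots; ++s) {
      if (!slot[s].has_value() && next_to_insert < steps.size()) {
        slot[s] = next_to_insert;
        remaining[next_to_insert] = steps[next_to_insert];
        std::printf("step %d: insert query %zu into slot %zu\n", step,
                    next_to_insert, s);
        ++next_to_insert;
      }
    }
    // One decode step for every occupied slot (the Transformer/sampling phase).
    for (size_t s = 0; s < kSlots; ++s) {
      if (!slot[s].has_value()) continue;
      if (--remaining[*slot[s]] == 0) {
        // "EOS": the slot frees its query and becomes available again, like
        // MaybeReleaseKV() returning the KV cache to available_kv_caches_.
        std::printf("step %d: query %zu finished, slot %zu released\n", step,
                    *slot[s], s);
        slot[s].reset();
        ++finished;
      }
    }
  }
  return 0;
}

With two slots and the step counts above, query 2 starts decoding as soon as
query 1 finishes, rather than waiting for the whole first batch to drain,
which is the throughput benefit continuous batching targets.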