Implement Continus Batching.

(1) A function GenerateTWithContinuousBatching is added to use continuous batching when enabled.

(2) The ContinuousQBatch is added as a subclass of QBatch to manage prefill, insert, used-kv-cache-collection.

(3) Also expanded the unit test to more diverse cases.

PiperOrigin-RevId: 836090261
This commit is contained in:
Charles Zhao 2025-11-23 23:53:28 -08:00 committed by Copybara-Service
parent 88a03b7ec4
commit 0e5f4cbf1b
4 changed files with 187 additions and 21 deletions

View File

@ -18,6 +18,8 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h" #include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
@ -357,6 +359,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt, (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
token, 0.0f); token, 0.0f);
qbatch.MutablePos(qi) = pos_in_prompt; qbatch.MutablePos(qi) = pos_in_prompt;
} else {
// This prevents the kv cache of eos_id to be written to last prefilled
// token.
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
} }
qbatch.PrevToken(qi) = token; qbatch.PrevToken(qi) = token;
@ -589,6 +595,57 @@ static void GenerateT(const ModelConfig& config,
timing_info.NotifyGenerateDone(); timing_info.NotifyGenerateDone();
} }
// Same as GenerateT, but uses ContinuousQBatch.
static void GenerateTWithContinuousBatching(
const ModelConfig& config, const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) {
const size_t qbatch_size = runtime_config.decode_qbatch_size;
QBatch qbatch(0, qbatch_size, all_queries);
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
hwy::BitSet4096<> non_eos;
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
int query_inserted = 0;
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
// Continue if qi slot is still processing.
if (non_eos.Get(qi)) continue;
// Collect the kv_cache from the qi slot in the qbatch to the
// available_kv_caches_ in the prefill_batch.
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
// Prefill if no available prefilled queries to insert.
if (prefill_batch.ShouldPrefill()) {
prefill_batch.SetupNextBatchForPrefill();
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
prefill_batch, env, timing_info);
activations.SetBatchSize(qbatch.Size());
}
// Get the next query to insert to the generate batch.
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
if (qi_to_insert) {
qbatch.Insert(qi_to_insert.value(), qi);
query_inserted++;
non_eos.Set(qi);
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
qi);
}
}
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config, const ModelConfig& config,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
@ -619,12 +676,17 @@ void GenerateBatchT(const ModelConfig& config,
all_queries[0].kv_cache.SeqLen(), env.ctx, all_queries[0].kv_cache.SeqLen(), env.ctx,
env.row_ptrs); env.row_ptrs);
if (runtime_config.use_continuous_batching) {
GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
activations, all_queries, env, timing_info);
} else {
for (size_t start = 0; start < all_queries.NumQueries(); for (size_t start = 0; start < all_queries.NumQueries();
start += runtime_config.decode_qbatch_size) { start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries. // Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, GenerateT(config, runtime_config, engine, weights, activations, qbatch,
timing_info); env, timing_info);
}
} }
} }
@ -721,5 +783,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
} }
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
: QBatch(0, max_size, queries) {
for (size_t i = start_; i < queries_.NumQueries(); ++i) {
if (!queries_[i].kv_cache.IsEmpty()) {
// Put the kv_cache to the available_kv_caches_ instead; leaving the
// kv_cache in the queries_ is very confusing. This simplifies the logic
// of kv_cache management.
available_kv_caches_.push_back(queries_[i].kv_cache);
queries_[i].kv_cache = KVCachePtr();
}
}
}
bool ContinuousQBatch::ShouldPrefill() const {
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
const int more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
return no_available_to_insert && more_queries_to_prefill;
}
void ContinuousQBatch::SetupNextBatchForPrefill() {
start_ = next_to_prefill_;
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
query_idx_.clear();
query_idx_.reserve(size_);
for (size_t i = 0; i < size_; ++i) {
const size_t next_query_idx = start_ + i;
query_idx_.push_back(next_query_idx);
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
available_kv_caches_.pop_back();
}
next_to_prefill_ += size_;
}
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
if (next_to_insert_ == next_to_prefill_) {
return std::nullopt;
}
next_to_insert_++;
return next_to_insert_ - 1;
}
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
const int query_to_collect = from.QueryIdx(0);
// Only collect if the query to collect is not the same as the next query to
// insert. This happens at the beginning of each Generate call.
if (query_to_collect != next_to_insert_) {
// Only clear the KV cache if there are more queries to insert; Otherwise
// we get a crash because Transformer will still access that KV cache.
if (next_to_insert_ < queries_.NumQueries()) {
available_kv_caches_.push_back(from.KV(0));
ZeroInit(from.KV(0).kv_cache);
from.KV(0) = KVCachePtr();
}
}
}
} // namespace gcpp } // namespace gcpp
#endif // HWY_ONCE #endif // HWY_ONCE

View File

@ -18,6 +18,7 @@
#include <stdio.h> #include <stdio.h>
#include <optional>
#include <vector> #include <vector>
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
@ -89,17 +90,17 @@ struct AllQueries {
const hwy::Span<const PromptTokens>& prompts, const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCachePtr>& kv_caches, const hwy::Span<KVCachePtr>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) { const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0); HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size()); per_query_.reserve(prompts.size());
for (size_t i = 0; i < kv_caches.size(); ++i) { for (size_t i = 0; i < prompts.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen()); HWY_ASSERT(kv_caches.size() == 0 ||
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{ per_query_.push_back(PerQuery{
.prompt = prompts[i], .prompt = prompts[i],
.mutable_pos = 0, .mutable_pos = 0,
.initial_pos = 0, .initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i], .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i], .kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
}); });
} }
} }
@ -142,10 +143,13 @@ class QBatch {
HWY_ASSERT(max_size_ <= kMaxBatchSize); HWY_ASSERT(max_size_ <= kMaxBatchSize);
HWY_DASSERT(size_ != 0); HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries()); HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
for (int i = 0; i < size_; ++i) {
query_idx_.push_back(start_ + i);
}
} }
// Returns a single-query view starting at `qi` relative to this batch. // Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); } QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`. // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; } size_t Size() const { return size_; }
@ -153,7 +157,7 @@ class QBatch {
// Returns index for use with `AllQueries` and `BatchStreamToken`. // Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const { size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_); HWY_DASSERT(qi < size_);
return start_ + qi; return query_idx_[qi];
} }
// Accessor functions to bridge the previous SoA and current AoS layout. // Accessor functions to bridge the previous SoA and current AoS layout.
@ -171,13 +175,48 @@ class QBatch {
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; } KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; } int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private: // let query_idx_[to] point to the from in the queries_; this is only used if
// the slot in the QBatch is less than the number of queries.
void Insert(size_t from, size_t to) {
if (from == to) return;
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
// Conceptually, insert from.query to location to.
query_idx_[to] = from;
}
protected:
size_t start_; size_t start_;
size_t max_size_; size_t max_size_;
AllQueries& queries_; AllQueries& queries_;
std::vector<size_t> query_idx_;
size_t size_; size_t size_;
}; };
// Used for continuous batching.
class ContinuousQBatch : public QBatch {
public:
ContinuousQBatch(size_t max_size, AllQueries& queries);
// Whether we should prefill the next batch, i.e. next_to_insert_ ==
// next_to_prefill_.
bool ShouldPrefill() const;
// Setup the query_idx_ to point to the next group of queries to prefill.
void SetupNextBatchForPrefill();
// Get the next query to insert to the generate batch.
std::optional<size_t> GetNextToInsert();
// Collect the kv_cache from QBatch to available_kv_caches_.
void MaybeReleaseKV(const QBatch& from);
public:
int next_to_prefill_ = 0;
int next_to_insert_ = 0;
std::vector<KVCachePtr> available_kv_caches_;
};
struct TimingInfo { struct TimingInfo {
// be sure to populate prefill_start before calling NotifyPrefill. // be sure to populate prefill_start before calling NotifyPrefill.
void NotifyPrefill(size_t tokens) { void NotifyPrefill(size_t tokens) {

View File

@ -163,6 +163,9 @@ struct RuntimeConfig {
// default decision is likely sufficient because it is based on whether // default decision is likely sufficient because it is based on whether
// threads are successfully pinned. // threads are successfully pinned.
mutable Tristate use_spinning = Tristate::kDefault; mutable Tristate use_spinning = Tristate::kDefault;
// Whether to use continuous batching.
bool use_continuous_batching = false;
}; };
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {

View File

@ -28,6 +28,13 @@ namespace gcpp {
using KV_t = float; using KV_t = float;
// A non-owning view of a KVCache.
struct KVCachePtr {
bool IsEmpty() const { return kv_cache.Rows() == 0; }
size_t SeqLen() const { return kv_cache.Rows(); }
MatPtrT<KV_t> kv_cache;
};
struct KVCache { struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args, KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator); const Allocator& allocator);
@ -40,6 +47,8 @@ struct KVCache {
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2] MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
KVCachePtr ToPtr() { return KVCachePtr{.kv_cache = kv_cache}; }
private: private:
const Allocator& allocator_; const Allocator& allocator_;
@ -47,12 +56,6 @@ struct KVCache {
KVCache(const Extents2D& kv_extents, const Allocator& allocator); KVCache(const Extents2D& kv_extents, const Allocator& allocator);
}; };
// A non-owning view of a KVCache.
struct KVCachePtr {
size_t SeqLen() const { return kv_cache.Rows(); }
MatPtrT<KV_t> kv_cache;
};
// Convenience function to create views into KVCaches. // Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches); std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);