Implement Continuous Batching.

(1) Added a function GenerateTWithContinuousBatching, which is used when continuous batching is enabled.

(2) Added ContinuousQBatch, a subclass of QBatch that manages prefilling, inserting prefilled queries into the generate batch, and collecting released KV caches.

(3) Expanded the unit tests to cover more diverse cases.
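
For illustration, enabling the new mode only requires the new RuntimeConfig flag (a minimal sketch with example values; the remaining inference setup is elided):

  RuntimeConfig runtime_config;
  runtime_config.decode_qbatch_size = 4;          // existing field: number of decode slots
  runtime_config.use_continuous_batching = true;  // new flag added by this change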

PiperOrigin-RevId: 836090261
Charles Zhao 2025-11-23 23:53:28 -08:00 committed by Copybara-Service
parent 88a03b7ec4
commit 0e5f4cbf1b
4 changed files with 187 additions and 21 deletions


@@ -18,6 +18,8 @@
#include "gemma/gemma.h"
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
@@ -357,6 +359,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
token, 0.0f);
qbatch.MutablePos(qi) = pos_in_prompt;
} else {
// This prevents the kv cache of eos_id from being written to the last
// prefilled token.
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
}
qbatch.PrevToken(qi) = token;
@@ -589,6 +595,57 @@ static void GenerateT(const ModelConfig& config,
timing_info.NotifyGenerateDone();
}
// Same as GenerateT, but uses ContinuousQBatch.
static void GenerateTWithContinuousBatching(
const ModelConfig& config, const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
TimingInfo& timing_info) {
const size_t qbatch_size = runtime_config.decode_qbatch_size;
QBatch qbatch(0, qbatch_size, all_queries);
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
hwy::BitSet4096<> non_eos;
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);
int query_inserted = 0;
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
// Skip slot qi if it is still processing a query.
if (non_eos.Get(qi)) continue;
// Collect the kv_cache from slot qi of the qbatch into the
// available_kv_caches_ pool of the prefill_batch.
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
// Prefill if there are no prefilled queries available to insert.
if (prefill_batch.ShouldPrefill()) {
prefill_batch.SetupNextBatchForPrefill();
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
prefill_batch, env, timing_info);
activations.SetBatchSize(qbatch.Size());
}
// Get the next query to insert into the generate batch.
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
if (qi_to_insert) {
qbatch.Insert(qi_to_insert.value(), qi);
query_inserted++;
non_eos.Set(qi);
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
qi);
}
}
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
@@ -619,12 +676,17 @@ void GenerateBatchT(const ModelConfig& config,
all_queries[0].kv_cache.SeqLen(), env.ctx,
env.row_ptrs);
if (runtime_config.use_continuous_batching) {
GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
activations, all_queries, env, timing_info);
} else {
for (size_t start = 0; start < all_queries.NumQueries();
start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
timing_info);
GenerateT(config, runtime_config, engine, weights, activations, qbatch,
env, timing_info);
}
}
}
@@ -721,5 +783,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
: QBatch(0, max_size, queries) {
for (size_t i = start_; i < queries_.NumQueries(); ++i) {
if (!queries_[i].kv_cache.IsEmpty()) {
// Move the kv_cache into available_kv_caches_ instead; leaving the
// kv_cache in the queries_ would be confusing. This simplifies the logic
// of kv_cache management.
available_kv_caches_.push_back(queries_[i].kv_cache);
queries_[i].kv_cache = KVCachePtr();
}
}
}
bool ContinuousQBatch::ShouldPrefill() const {
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
const bool more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
return no_available_to_insert && more_queries_to_prefill;
}
void ContinuousQBatch::SetupNextBatchForPrefill() {
start_ = next_to_prefill_;
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
query_idx_.clear();
query_idx_.reserve(size_);
for (size_t i = 0; i < size_; ++i) {
const size_t next_query_idx = start_ + i;
query_idx_.push_back(next_query_idx);
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
available_kv_caches_.pop_back();
}
next_to_prefill_ += size_;
}
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
if (next_to_insert_ == next_to_prefill_) {
return std::nullopt;
}
next_to_insert_++;
return next_to_insert_ - 1;
}
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
const int query_to_collect = from.QueryIdx(0);
// Only collect if the query to collect differs from the next query to
// insert; they are equal at the beginning of each Generate call.
if (query_to_collect != next_to_insert_) {
// Only clear the KV cache if there are more queries to insert; otherwise
// we would crash because Transformer still accesses that KV cache.
if (next_to_insert_ < queries_.NumQueries()) {
available_kv_caches_.push_back(from.KV(0));
ZeroInit(from.KV(0).kv_cache);
from.KV(0) = KVCachePtr();
}
}
}
} // namespace gcpp
#endif // HWY_ONCE
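
For readers of the loop above, the following standalone toy sketch (not gemma.cpp code; plain ints stand in for queries, tokens and KV caches) shows the slot-refill scheduling that GenerateTWithContinuousBatching performs: whenever a slot finishes its query, its resources are released and the slot is immediately refilled with the next pending query.

#include <cstdio>
#include <vector>

int main() {
  const int kNumSlots = 2;                       // stand-in for decode_qbatch_size
  std::vector<int> remaining = {3, 1, 2, 2, 1};  // fake decode lengths per query
  const int kNumQueries = static_cast<int>(remaining.size());
  std::vector<int> slots(kNumSlots, -1);         // -1: free slot (cache released)
  int next_query = 0, finished = 0;
  while (finished < kNumQueries) {
    // Refill free slots with the next pending ("prefilled") queries.
    for (int s = 0; s < kNumSlots; ++s) {
      if (slots[s] == -1 && next_query < kNumQueries) {
        slots[s] = next_query++;
        std::printf("slot %d <- query %d\n", s, slots[s]);
      }
    }
    // One decode step per occupied slot; free slots whose query hit EOS.
    for (int s = 0; s < kNumSlots; ++s) {
      if (slots[s] == -1) continue;
      if (--remaining[slots[s]] == 0) {
        std::printf("slot %d finished query %d\n", s, slots[s]);
        slots[s] = -1;
        ++finished;
      }
    }
  }
  return 0;
}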


@@ -18,6 +18,7 @@
#include <stdio.h>
#include <optional>
#include <vector>
// IWYU pragma: begin_exports
@@ -89,17 +90,17 @@ struct AllQueries {
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCachePtr>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.reserve(prompts.size());
for (size_t i = 0; i < prompts.size(); ++i) {
HWY_ASSERT(kv_caches.size() == 0 ||
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i],
.kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
});
}
}
@@ -142,10 +143,13 @@ class QBatch {
HWY_ASSERT(max_size_ <= kMaxBatchSize);
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
for (int i = 0; i < size_; ++i) {
query_idx_.push_back(start_ + i);
}
}
// Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; }
@@ -153,7 +157,7 @@
// Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_);
return start_ + qi;
return query_idx_[qi];
}
// Accessor functions to bridge the previous SoA and current AoS layout.
@@ -171,13 +175,48 @@
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private:
// Let query_idx_[to] point to query `from` in queries_. This is only used
// when the QBatch has fewer slots than there are queries.
void Insert(size_t from, size_t to) {
if (from == to) return;
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
// Conceptually, insert from.query to location to.
query_idx_[to] = from;
}
protected:
size_t start_;
size_t max_size_;
AllQueries& queries_;
std::vector<size_t> query_idx_;
size_t size_;
};
// Used for continuous batching.
class ContinuousQBatch : public QBatch {
public:
ContinuousQBatch(size_t max_size, AllQueries& queries);
// Whether we should prefill the next batch, i.e. next_to_insert_ ==
// next_to_prefill_.
bool ShouldPrefill() const;
// Set up query_idx_ to point to the next group of queries to prefill.
void SetupNextBatchForPrefill();
// Get the next query to insert into the generate batch.
std::optional<size_t> GetNextToInsert();
// Collect the kv_cache from the given QBatch into available_kv_caches_.
void MaybeReleaseKV(const QBatch& from);
public:
int next_to_prefill_ = 0;
int next_to_insert_ = 0;
std::vector<KVCachePtr> available_kv_caches_;
};
struct TimingInfo {
// be sure to populate prefill_start before calling NotifyPrefill.
void NotifyPrefill(size_t tokens) {

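The QBatch change above replaces the implicit slot-to-query mapping (start_ + qi) with an explicit query_idx_ table, so Insert can repoint a finished slot at a freshly prefilled query without moving any per-query state. A minimal standalone illustration of that indirection (not gemma.cpp code):

#include <cstdio>
#include <vector>

int main() {
  // Four queries overall, but only two decode slots.
  const char* all_queries[] = {"q0", "q1", "q2", "q3"};
  std::vector<size_t> query_idx = {0, 1};  // slot -> query index, initially contiguous
  query_idx[1] = 3;                        // "Insert": slot 1 now decodes query 3
  for (size_t slot = 0; slot < query_idx.size(); ++slot) {
    std::printf("slot %zu -> %s\n", slot, all_queries[query_idx[slot]]);
  }
  return 0;
}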

@@ -163,6 +163,9 @@ struct RuntimeConfig {
// default decision is likely sufficient because it is based on whether
// threads are successfully pinned.
mutable Tristate use_spinning = Tristate::kDefault;
// Whether to use continuous batching.
bool use_continuous_batching = false;
};
struct InferenceArgs : public ArgsBase<InferenceArgs> {


@@ -28,6 +28,13 @@ namespace gcpp {
using KV_t = float;
// A non-owning view of a KVCache.
struct KVCachePtr {
bool IsEmpty() const { return kv_cache.Rows() == 0; }
size_t SeqLen() const { return kv_cache.Rows(); }
MatPtrT<KV_t> kv_cache;
};
struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator);
@@ -40,6 +47,8 @@
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
KVCachePtr ToPtr() { return KVCachePtr{.kv_cache = kv_cache}; }
private:
const Allocator& allocator_;
@@ -47,12 +56,6 @@
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
};
// A non-owning view of a KVCache.
struct KVCachePtr {
size_t SeqLen() const { return kv_cache.Rows(); }
MatPtrT<KV_t> kv_cache;
};
// Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
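
Note on the relocated KVCachePtr: a default-constructed view has Rows() == 0, so IsEmpty() returns true; ContinuousQBatch uses that state to mark queries whose cache has been released back to the pool. Illustrative fragment, assuming the declarations above are in scope (not a complete program):

KVCachePtr view;  // default view: view.IsEmpty() == true, no storage attached
// SetupNextBatchForPrefill() assigns a pooled cache to a query before prefill;
// MaybeReleaseKV() later resets the query's view back to KVCachePtr().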