mirror of https://github.com/google/gemma.cpp.git
Implement Continus Batching.
(1) A function GenerateTWithContinuousBatching is added to use continuous batching when enabled. (2) The ContinuousQBatch is added as a subclass of QBatch to manage prefill, insert, used-kv-cache-collection. (3) Also expanded the unit test to more diverse cases. PiperOrigin-RevId: 836090261
This commit is contained in:
parent
88a03b7ec4
commit
0e5f4cbf1b
135
gemma/gemma.cc
135
gemma/gemma.cc
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "util/zones.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
|
|
@ -35,7 +37,7 @@
|
|||
// After highway.h
|
||||
#include "gemma/attention.h" // includes highway.h
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
|
||||
#ifndef GEMMA_CC_ONCE
|
||||
#define GEMMA_CC_ONCE
|
||||
|
|
@ -357,6 +359,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
|||
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
|
||||
token, 0.0f);
|
||||
qbatch.MutablePos(qi) = pos_in_prompt;
|
||||
} else {
|
||||
// This prevents the kv cache of eos_id to be written to last prefilled
|
||||
// token.
|
||||
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
|
||||
}
|
||||
|
||||
qbatch.PrevToken(qi) = token;
|
||||
|
|
@ -589,6 +595,57 @@ static void GenerateT(const ModelConfig& config,
|
|||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
||||
// Same as GenerateT, but uses ContinuousQBatch.
|
||||
static void GenerateTWithContinuousBatching(
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t qbatch_size = runtime_config.decode_qbatch_size;
|
||||
|
||||
QBatch qbatch(0, qbatch_size, all_queries);
|
||||
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
|
||||
|
||||
hwy::BitSet4096<> non_eos;
|
||||
const SampleFunc sample_token =
|
||||
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||
|
||||
int query_inserted = 0;
|
||||
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
// Continue if qi slot is still processing.
|
||||
if (non_eos.Get(qi)) continue;
|
||||
// Collect the kv_cache from the qi slot in the qbatch to the
|
||||
// available_kv_caches_ in the prefill_batch.
|
||||
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
|
||||
|
||||
// Prefill if no available prefilled queries to insert.
|
||||
if (prefill_batch.ShouldPrefill()) {
|
||||
prefill_batch.SetupNextBatchForPrefill();
|
||||
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
|
||||
prefill_batch, env, timing_info);
|
||||
activations.SetBatchSize(qbatch.Size());
|
||||
}
|
||||
|
||||
// Get the next query to insert to the generate batch.
|
||||
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
|
||||
if (qi_to_insert) {
|
||||
qbatch.Insert(qi_to_insert.value(), qi);
|
||||
query_inserted++;
|
||||
|
||||
non_eos.Set(qi);
|
||||
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
|
||||
qi);
|
||||
}
|
||||
}
|
||||
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
||||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
|
|
@ -619,12 +676,17 @@ void GenerateBatchT(const ModelConfig& config,
|
|||
all_queries[0].kv_cache.SeqLen(), env.ctx,
|
||||
env.row_ptrs);
|
||||
|
||||
for (size_t start = 0; start < all_queries.NumQueries();
|
||||
start += runtime_config.decode_qbatch_size) {
|
||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
if (runtime_config.use_continuous_batching) {
|
||||
GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
|
||||
activations, all_queries, env, timing_info);
|
||||
} else {
|
||||
for (size_t start = 0; start < all_queries.NumQueries();
|
||||
start += runtime_config.decode_qbatch_size) {
|
||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||
GenerateT(config, runtime_config, engine, weights, activations, qbatch,
|
||||
env, timing_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -721,5 +783,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
|||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
|
||||
: QBatch(0, max_size, queries) {
|
||||
for (size_t i = start_; i < queries_.NumQueries(); ++i) {
|
||||
if (!queries_[i].kv_cache.IsEmpty()) {
|
||||
// Put the kv_cache to the available_kv_caches_ instead; leaving the
|
||||
// kv_cache in the queries_ is very confusing. This simplifies the logic
|
||||
// of kv_cache management.
|
||||
available_kv_caches_.push_back(queries_[i].kv_cache);
|
||||
queries_[i].kv_cache = KVCachePtr();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ContinuousQBatch::ShouldPrefill() const {
|
||||
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
|
||||
const int more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
|
||||
return no_available_to_insert && more_queries_to_prefill;
|
||||
}
|
||||
|
||||
void ContinuousQBatch::SetupNextBatchForPrefill() {
|
||||
start_ = next_to_prefill_;
|
||||
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
query_idx_.clear();
|
||||
query_idx_.reserve(size_);
|
||||
for (size_t i = 0; i < size_; ++i) {
|
||||
const size_t next_query_idx = start_ + i;
|
||||
query_idx_.push_back(next_query_idx);
|
||||
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
|
||||
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
|
||||
available_kv_caches_.pop_back();
|
||||
}
|
||||
next_to_prefill_ += size_;
|
||||
}
|
||||
|
||||
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
|
||||
if (next_to_insert_ == next_to_prefill_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
next_to_insert_++;
|
||||
return next_to_insert_ - 1;
|
||||
}
|
||||
|
||||
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
|
||||
const int query_to_collect = from.QueryIdx(0);
|
||||
// Only collect if the query to collect is not the same as the next query to
|
||||
// insert. This happens at the beginning of each Generate call.
|
||||
if (query_to_collect != next_to_insert_) {
|
||||
// Only clear the KV cache if there are more queries to insert; Otherwise
|
||||
// we get a crash because Transformer will still access that KV cache.
|
||||
if (next_to_insert_ < queries_.NumQueries()) {
|
||||
available_kv_caches_.push_back(from.KV(0));
|
||||
ZeroInit(from.KV(0).kv_cache);
|
||||
from.KV(0) = KVCachePtr();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
|
|
@ -89,17 +90,17 @@ struct AllQueries {
|
|||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCachePtr>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.reserve(prompts.size());
|
||||
for (size_t i = 0; i < prompts.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches.size() == 0 ||
|
||||
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompts[i],
|
||||
.mutable_pos = 0,
|
||||
.initial_pos = 0,
|
||||
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||
.kv_cache = kv_caches[i],
|
||||
.kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -142,10 +143,13 @@ class QBatch {
|
|||
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
for (int i = 0; i < size_; ++i) {
|
||||
query_idx_.push_back(start_ + i);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a single-query view starting at `qi` relative to this batch.
|
||||
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
||||
QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
|
||||
|
||||
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||
size_t Size() const { return size_; }
|
||||
|
|
@ -153,7 +157,7 @@ class QBatch {
|
|||
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||
size_t QueryIdx(size_t qi) const {
|
||||
HWY_DASSERT(qi < size_);
|
||||
return start_ + qi;
|
||||
return query_idx_[qi];
|
||||
}
|
||||
|
||||
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||
|
|
@ -171,13 +175,48 @@ class QBatch {
|
|||
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||
|
||||
private:
|
||||
// let query_idx_[to] point to the from in the queries_; this is only used if
|
||||
// the slot in the QBatch is less than the number of queries.
|
||||
void Insert(size_t from, size_t to) {
|
||||
if (from == to) return;
|
||||
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
|
||||
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
|
||||
// Conceptually, insert from.query to location to.
|
||||
query_idx_[to] = from;
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t start_;
|
||||
size_t max_size_;
|
||||
AllQueries& queries_;
|
||||
std::vector<size_t> query_idx_;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
// Used for continuous batching.
|
||||
class ContinuousQBatch : public QBatch {
|
||||
public:
|
||||
ContinuousQBatch(size_t max_size, AllQueries& queries);
|
||||
|
||||
// Whether we should prefill the next batch, i.e. next_to_insert_ ==
|
||||
// next_to_prefill_.
|
||||
bool ShouldPrefill() const;
|
||||
|
||||
// Setup the query_idx_ to point to the next group of queries to prefill.
|
||||
void SetupNextBatchForPrefill();
|
||||
|
||||
// Get the next query to insert to the generate batch.
|
||||
std::optional<size_t> GetNextToInsert();
|
||||
|
||||
// Collect the kv_cache from QBatch to available_kv_caches_.
|
||||
void MaybeReleaseKV(const QBatch& from);
|
||||
|
||||
public:
|
||||
int next_to_prefill_ = 0;
|
||||
int next_to_insert_ = 0;
|
||||
std::vector<KVCachePtr> available_kv_caches_;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||
void NotifyPrefill(size_t tokens) {
|
||||
|
|
|
|||
|
|
@ -163,6 +163,9 @@ struct RuntimeConfig {
|
|||
// default decision is likely sufficient because it is based on whether
|
||||
// threads are successfully pinned.
|
||||
mutable Tristate use_spinning = Tristate::kDefault;
|
||||
|
||||
// Whether to use continuous batching.
|
||||
bool use_continuous_batching = false;
|
||||
};
|
||||
|
||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||
|
|
|
|||
|
|
@ -28,6 +28,13 @@ namespace gcpp {
|
|||
|
||||
using KV_t = float;
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
struct KVCachePtr {
|
||||
bool IsEmpty() const { return kv_cache.Rows() == 0; }
|
||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||
MatPtrT<KV_t> kv_cache;
|
||||
};
|
||||
|
||||
struct KVCache {
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const Allocator& allocator);
|
||||
|
|
@ -40,6 +47,8 @@ struct KVCache {
|
|||
|
||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
|
||||
KVCachePtr ToPtr() { return KVCachePtr{.kv_cache = kv_cache}; }
|
||||
|
||||
private:
|
||||
const Allocator& allocator_;
|
||||
|
||||
|
|
@ -47,12 +56,6 @@ struct KVCache {
|
|||
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
|
||||
};
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
struct KVCachePtr {
|
||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||
MatPtrT<KV_t> kv_cache;
|
||||
};
|
||||
|
||||
// Convenience function to create views into KVCaches.
|
||||
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue