mirror of https://github.com/google/gemma.cpp.git
Implement Continuous Batching.
(1) A function GenerateTWithContinuousBatching is added to use continuous batching when enabled. (2) ContinuousQBatch is added as a subclass of QBatch to manage prefill, insertion, and collection of used KV caches. (3) The unit test is also expanded to cover more diverse cases. PiperOrigin-RevId: 836090261
This commit is contained in:
parent
88a03b7ec4
commit
0e5f4cbf1b
135
gemma/gemma.cc
135
gemma/gemma.cc
|
|
@ -18,6 +18,8 @@
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
#include "util/zones.h"
|
#include "util/zones.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
|
|
@ -35,7 +37,7 @@
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "gemma/attention.h" // includes highway.h
|
#include "gemma/attention.h" // includes highway.h
|
||||||
#include "gemma/gemma-inl.h"
|
#include "gemma/gemma-inl.h"
|
||||||
#include "gemma/vit.h" // includes highway.h
|
#include "gemma/vit.h" // includes highway.h
|
||||||
|
|
||||||
#ifndef GEMMA_CC_ONCE
|
#ifndef GEMMA_CC_ONCE
|
||||||
#define GEMMA_CC_ONCE
|
#define GEMMA_CC_ONCE
|
||||||
|
|
@ -357,6 +359,10 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||||
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
|
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt,
|
||||||
token, 0.0f);
|
token, 0.0f);
|
||||||
qbatch.MutablePos(qi) = pos_in_prompt;
|
qbatch.MutablePos(qi) = pos_in_prompt;
|
||||||
|
} else {
|
||||||
|
// This prevents the kv cache of eos_id to be written to last prefilled
|
||||||
|
// token.
|
||||||
|
qbatch.MutablePos(qi) = qbatch.Prompt(qi).size();
|
||||||
}
|
}
|
||||||
|
|
||||||
qbatch.PrevToken(qi) = token;
|
qbatch.PrevToken(qi) = token;
|
||||||
|
|
@ -589,6 +595,57 @@ static void GenerateT(const ModelConfig& config,
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Same as GenerateT, but uses ContinuousQBatch.
|
||||||
|
static void GenerateTWithContinuousBatching(
|
||||||
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
|
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||||
|
Activations& activations, AllQueries& all_queries, MatMulEnv& env,
|
||||||
|
TimingInfo& timing_info) {
|
||||||
|
const size_t qbatch_size = runtime_config.decode_qbatch_size;
|
||||||
|
|
||||||
|
QBatch qbatch(0, qbatch_size, all_queries);
|
||||||
|
ContinuousQBatch prefill_batch(qbatch_size, all_queries);
|
||||||
|
|
||||||
|
hwy::BitSet4096<> non_eos;
|
||||||
|
const SampleFunc sample_token =
|
||||||
|
ChooseSampleFunc(runtime_config, engine, env.ctx);
|
||||||
|
|
||||||
|
int query_inserted = 0;
|
||||||
|
while (non_eos.Any() || query_inserted < all_queries.NumQueries()) {
|
||||||
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
|
// Continue if qi slot is still processing.
|
||||||
|
if (non_eos.Get(qi)) continue;
|
||||||
|
// Collect the kv_cache from the qi slot in the qbatch to the
|
||||||
|
// available_kv_caches_ in the prefill_batch.
|
||||||
|
prefill_batch.MaybeReleaseKV(qbatch.Single(qi));
|
||||||
|
|
||||||
|
// Prefill if no available prefilled queries to insert.
|
||||||
|
if (prefill_batch.ShouldPrefill()) {
|
||||||
|
prefill_batch.SetupNextBatchForPrefill();
|
||||||
|
PrefillTBatchOrQBatch(config, runtime_config, weights, activations,
|
||||||
|
prefill_batch, env, timing_info);
|
||||||
|
activations.SetBatchSize(qbatch.Size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the next query to insert to the generate batch.
|
||||||
|
std::optional<size_t> qi_to_insert = prefill_batch.GetNextToInsert();
|
||||||
|
if (qi_to_insert) {
|
||||||
|
qbatch.Insert(qi_to_insert.value(), qi);
|
||||||
|
query_inserted++;
|
||||||
|
|
||||||
|
non_eos.Set(qi);
|
||||||
|
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos,
|
||||||
|
qi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
|
SampleAndStream(config, runtime_config, weights, sample_token, activations,
|
||||||
|
qbatch, env, non_eos, timing_info);
|
||||||
|
}
|
||||||
|
timing_info.NotifyGenerateDone();
|
||||||
|
}
|
||||||
|
|
||||||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||||
const ModelConfig& config,
|
const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
|
|
@ -619,12 +676,17 @@ void GenerateBatchT(const ModelConfig& config,
|
||||||
all_queries[0].kv_cache.SeqLen(), env.ctx,
|
all_queries[0].kv_cache.SeqLen(), env.ctx,
|
||||||
env.row_ptrs);
|
env.row_ptrs);
|
||||||
|
|
||||||
for (size_t start = 0; start < all_queries.NumQueries();
|
if (runtime_config.use_continuous_batching) {
|
||||||
start += runtime_config.decode_qbatch_size) {
|
GenerateTWithContinuousBatching(config, runtime_config, engine, weights,
|
||||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
activations, all_queries, env, timing_info);
|
||||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
} else {
|
||||||
GenerateT(config, runtime_config, engine, weights, activations, qbatch, env,
|
for (size_t start = 0; start < all_queries.NumQueries();
|
||||||
timing_info);
|
start += runtime_config.decode_qbatch_size) {
|
||||||
|
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||||
|
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||||
|
GenerateT(config, runtime_config, engine, weights, activations, qbatch,
|
||||||
|
env, timing_info);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -721,5 +783,64 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||||
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a batch over `queries` starting at query 0 and moves every non-empty
// KV cache view out of the queries into `available_kv_caches_`.
ContinuousQBatch::ContinuousQBatch(size_t max_size, AllQueries& queries)
    : QBatch(0, max_size, queries) {
  const size_t num_queries = queries_.NumQueries();
  for (size_t query = start_; query < num_queries; ++query) {
    KVCachePtr& kv = queries_[query].kv_cache;
    if (kv.IsEmpty()) continue;
    // Keep the view in the pool rather than in the query: a single place that
    // holds the free caches keeps KV cache management simple.
    available_kv_caches_.push_back(kv);
    kv = KVCachePtr();
  }
}
|
||||||
|
|
||||||
|
bool ContinuousQBatch::ShouldPrefill() const {
|
||||||
|
const bool no_available_to_insert = next_to_insert_ == next_to_prefill_;
|
||||||
|
const int more_queries_to_prefill = next_to_prefill_ < queries_.NumQueries();
|
||||||
|
return no_available_to_insert && more_queries_to_prefill;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContinuousQBatch::SetupNextBatchForPrefill() {
|
||||||
|
start_ = next_to_prefill_;
|
||||||
|
size_ = HWY_MIN(max_size_, queries_.NumQueries() - start_);
|
||||||
|
HWY_DASSERT(size_ != 0);
|
||||||
|
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||||
|
query_idx_.clear();
|
||||||
|
query_idx_.reserve(size_);
|
||||||
|
for (size_t i = 0; i < size_; ++i) {
|
||||||
|
const size_t next_query_idx = start_ + i;
|
||||||
|
query_idx_.push_back(next_query_idx);
|
||||||
|
HWY_ASSERT(queries_[next_query_idx].kv_cache.IsEmpty());
|
||||||
|
queries_[next_query_idx].kv_cache = available_kv_caches_.back();
|
||||||
|
available_kv_caches_.pop_back();
|
||||||
|
}
|
||||||
|
next_to_prefill_ += size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<size_t> ContinuousQBatch::GetNextToInsert() {
|
||||||
|
if (next_to_insert_ == next_to_prefill_) {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
next_to_insert_++;
|
||||||
|
return next_to_insert_ - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ContinuousQBatch::MaybeReleaseKV(const QBatch& from) {
|
||||||
|
const int query_to_collect = from.QueryIdx(0);
|
||||||
|
// Only collect if the query to collect is not the same as the next query to
|
||||||
|
// insert. This happens at the beginning of each Generate call.
|
||||||
|
if (query_to_collect != next_to_insert_) {
|
||||||
|
// Only clear the KV cache if there are more queries to insert; Otherwise
|
||||||
|
// we get a crash because Transformer will still access that KV cache.
|
||||||
|
if (next_to_insert_ < queries_.NumQueries()) {
|
||||||
|
available_kv_caches_.push_back(from.KV(0));
|
||||||
|
ZeroInit(from.KV(0).kv_cache);
|
||||||
|
from.KV(0) = KVCachePtr();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
#endif // HWY_ONCE
|
#endif // HWY_ONCE
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
|
|
@ -89,17 +90,17 @@ struct AllQueries {
|
||||||
const hwy::Span<const PromptTokens>& prompts,
|
const hwy::Span<const PromptTokens>& prompts,
|
||||||
const hwy::Span<KVCachePtr>& kv_caches,
|
const hwy::Span<KVCachePtr>& kv_caches,
|
||||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
|
||||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||||
per_query_.reserve(kv_caches.size());
|
per_query_.reserve(prompts.size());
|
||||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
for (size_t i = 0; i < prompts.size(); ++i) {
|
||||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
HWY_ASSERT(kv_caches.size() == 0 ||
|
||||||
|
kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||||
per_query_.push_back(PerQuery{
|
per_query_.push_back(PerQuery{
|
||||||
.prompt = prompts[i],
|
.prompt = prompts[i],
|
||||||
.mutable_pos = 0,
|
.mutable_pos = 0,
|
||||||
.initial_pos = 0,
|
.initial_pos = 0,
|
||||||
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||||
.kv_cache = kv_caches[i],
|
.kv_cache = kv_caches.size() == 0 ? KVCachePtr() : kv_caches[i],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -142,10 +143,13 @@ class QBatch {
|
||||||
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
HWY_ASSERT(max_size_ <= kMaxBatchSize);
|
||||||
HWY_DASSERT(size_ != 0);
|
HWY_DASSERT(size_ != 0);
|
||||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||||
|
for (int i = 0; i < size_; ++i) {
|
||||||
|
query_idx_.push_back(start_ + i);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a single-query view starting at `qi` relative to this batch.
|
// Returns a single-query view starting at `qi` relative to this batch.
|
||||||
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
QBatch Single(size_t qi) const { return QBatch(QueryIdx(qi), 1, queries_); }
|
||||||
|
|
||||||
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||||
size_t Size() const { return size_; }
|
size_t Size() const { return size_; }
|
||||||
|
|
@ -153,7 +157,7 @@ class QBatch {
|
||||||
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||||
size_t QueryIdx(size_t qi) const {
|
size_t QueryIdx(size_t qi) const {
|
||||||
HWY_DASSERT(qi < size_);
|
HWY_DASSERT(qi < size_);
|
||||||
return start_ + qi;
|
return query_idx_[qi];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessor functions to bridge the previous SoA and current AoS layout.
|
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||||
|
|
@ -171,13 +175,48 @@ class QBatch {
|
||||||
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
KVCachePtr& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||||
|
|
||||||
private:
|
// let query_idx_[to] point to the from in the queries_; this is only used if
|
||||||
|
// the slot in the QBatch is less than the number of queries.
|
||||||
|
void Insert(size_t from, size_t to) {
|
||||||
|
if (from == to) return;
|
||||||
|
HWY_ASSERT(!queries_[from].kv_cache.IsEmpty());
|
||||||
|
HWY_ASSERT(queries_[to].kv_cache.IsEmpty());
|
||||||
|
// Conceptually, insert from.query to location to.
|
||||||
|
query_idx_[to] = from;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
size_t start_;
|
size_t start_;
|
||||||
size_t max_size_;
|
size_t max_size_;
|
||||||
AllQueries& queries_;
|
AllQueries& queries_;
|
||||||
|
std::vector<size_t> query_idx_;
|
||||||
size_t size_;
|
size_t size_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Used for continuous batching.
|
||||||
|
class ContinuousQBatch : public QBatch {
|
||||||
|
public:
|
||||||
|
ContinuousQBatch(size_t max_size, AllQueries& queries);
|
||||||
|
|
||||||
|
// Whether we should prefill the next batch, i.e. next_to_insert_ ==
|
||||||
|
// next_to_prefill_.
|
||||||
|
bool ShouldPrefill() const;
|
||||||
|
|
||||||
|
// Setup the query_idx_ to point to the next group of queries to prefill.
|
||||||
|
void SetupNextBatchForPrefill();
|
||||||
|
|
||||||
|
// Get the next query to insert to the generate batch.
|
||||||
|
std::optional<size_t> GetNextToInsert();
|
||||||
|
|
||||||
|
// Collect the kv_cache from QBatch to available_kv_caches_.
|
||||||
|
void MaybeReleaseKV(const QBatch& from);
|
||||||
|
|
||||||
|
public:
|
||||||
|
int next_to_prefill_ = 0;
|
||||||
|
int next_to_insert_ = 0;
|
||||||
|
std::vector<KVCachePtr> available_kv_caches_;
|
||||||
|
};
|
||||||
|
|
||||||
struct TimingInfo {
|
struct TimingInfo {
|
||||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||||
void NotifyPrefill(size_t tokens) {
|
void NotifyPrefill(size_t tokens) {
|
||||||
|
|
|
||||||
|
|
@ -163,6 +163,9 @@ struct RuntimeConfig {
|
||||||
// default decision is likely sufficient because it is based on whether
|
// default decision is likely sufficient because it is based on whether
|
||||||
// threads are successfully pinned.
|
// threads are successfully pinned.
|
||||||
mutable Tristate use_spinning = Tristate::kDefault;
|
mutable Tristate use_spinning = Tristate::kDefault;
|
||||||
|
|
||||||
|
// Whether to use continuous batching.
|
||||||
|
bool use_continuous_batching = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,13 @@ namespace gcpp {
|
||||||
|
|
||||||
using KV_t = float;
|
using KV_t = float;
|
||||||
|
|
||||||
|
// A non-owning view of a KVCache.
|
||||||
|
struct KVCachePtr {
|
||||||
|
bool IsEmpty() const { return kv_cache.Rows() == 0; }
|
||||||
|
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||||
|
MatPtrT<KV_t> kv_cache;
|
||||||
|
};
|
||||||
|
|
||||||
struct KVCache {
|
struct KVCache {
|
||||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||||
const Allocator& allocator);
|
const Allocator& allocator);
|
||||||
|
|
@ -40,6 +47,8 @@ struct KVCache {
|
||||||
|
|
||||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||||
|
|
||||||
|
// Returns a KVCachePtr whose `kv_cache` member is initialized from this
// cache's `kv_cache`.
KVCachePtr ToPtr() { return KVCachePtr{kv_cache}; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const Allocator& allocator_;
|
const Allocator& allocator_;
|
||||||
|
|
||||||
|
|
@ -47,12 +56,6 @@ struct KVCache {
|
||||||
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
|
KVCache(const Extents2D& kv_extents, const Allocator& allocator);
|
||||||
};
|
};
|
||||||
|
|
||||||
// A non-owning view of a KVCache.
|
|
||||||
struct KVCachePtr {
|
|
||||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
|
||||||
MatPtrT<KV_t> kv_cache;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Convenience function to create views into KVCaches.
|
// Convenience function to create views into KVCaches.
|
||||||
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
|
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue