From 170a9b4690482dd1229a0876348cb898b34001f1 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Thu, 7 Mar 2024 14:08:48 +0800
Subject: [PATCH] Make `CreateKVCache` a free function rather than a method

---
 gemma.cc | 42 ++++++++++++++++++++++++------------------
 gemma.h  |  4 +++-
 run.cc   |  2 +-
 3 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index 31b3c38..ba6aafa 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -231,7 +231,6 @@ struct Activations {
 
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
-  virtual KVCache CreateKVCache() const = 0;
   virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
@@ -243,6 +242,23 @@ struct GemmaInterface {
                         int verbosity) = 0;
 };
 
+template <class Config>
+KVCache CreateKVCache() {
+  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                       Config::kSeqLen);
+}
+
+KVCache CreateKVCache(Model type) {
+  switch (type) {
+    case Model::GEMMA_2B:
+      return CreateKVCache<ConfigGemma2B>();
+    case Model::GEMMA_7B:
+      return CreateKVCache<ConfigGemma7B>();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
+  }
+}
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
   GemmaImpl(  // const LoaderArgs& args,
@@ -256,8 +272,6 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers();
   }
 
-  KVCache CreateKVCache() const override;
-
   const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
     return tokenizer.get();
   }
@@ -739,6 +753,13 @@ HWY_EXPORT(GetCompressedWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
 
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
+  KVCache kv_cache = {};
+  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  return kv_cache;
+}
+
 template <class Config>
 GemmaImpl<Config>::GemmaImpl(
     std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
@@ -755,17 +776,6 @@ GemmaImpl<Config>::GemmaImpl(
   //   HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 }
 
-template <class Config>
-KVCache GemmaImpl<Config>::CreateKVCache() const {
-  constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
-                                          Config::kQKVDim;
-  constexpr const size_t seq_len = Config::kSeqLen;
-  KVCache kv_cache = {};
-  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  return kv_cache;
-}
-
 template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
@@ -815,10 +825,6 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
 }
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
-KVCache Gemma::CreateKVCache() const {
-  return impl_->CreateKVCache();
-}
-
 const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
diff --git a/gemma.h b/gemma.h
index 03da2e5..58fd74a 100644
--- a/gemma.h
+++ b/gemma.h
@@ -157,13 +157,15 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
 
-  KVCache CreateKVCache() const;
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
   gcpp::ModelTraining model_training;
 };
 
+KVCache CreateKVCache(Model type);  // convenient workaround for now
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f.
 using StreamFunc = std::function<bool(int, float)>;

diff --git a/run.cc b/run.cc
index ff7ed3d..40be63e 100644
--- a/run.cc
+++ b/run.cc
@@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   gcpp::Gemma model(loader, pool);
-  auto kv_cache = model.CreateKVCache();
+  auto kv_cache = CreateKVCache(loader.ModelType());
 
   if (const char* error = inference.Validate()) {
     ShowHelp(loader, inference, app);
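
Note, not part of the patch: a minimal caller-side sketch of the new free
functions, mirroring the run.cc hunk above. It assumes KVCache, CreateKVCache,
and the Model enum are declared in the gcpp namespace in gemma.h (as the
run.cc usage suggests); the sizes passed to the second overload are made-up
placeholder values, not real model dimensions.

    #include "gemma.h"  // assumed to declare gcpp::KVCache, gcpp::CreateKVCache, gcpp::Model

    // Allocate a KV cache for a known model without constructing Gemma first.
    gcpp::KVCache MakeCacheFor2B() {
      return gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);
    }

    // Or size it explicitly: size_cache_pos = kLayers * kKVHeads * kQKVDim.
    // Placeholder numbers for illustration only.
    gcpp::KVCache MakeCustomCache() {
      return gcpp::CreateKVCache(/*size_cache_pos=*/4608, /*seq_len=*/4096);
    }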