diff --git a/gemma.cc b/gemma.cc index bbd86c3..31b3c38 100644 --- a/gemma.cc +++ b/gemma.cc @@ -231,12 +231,13 @@ struct Activations { struct GemmaInterface { virtual ~GemmaInterface() = default; + virtual KVCache CreateKVCache() const = 0; virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0; virtual void Generate(size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, - size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, + size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) = 0; @@ -255,22 +256,24 @@ struct GemmaImpl : public GemmaInterface { c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>(); } - const sentencepiece::SentencePieceProcessor* Tokenizer() const { + KVCache CreateKVCache() const override; + + const sentencepiece::SentencePieceProcessor* Tokenizer() const override { return tokenizer.get(); } void Generate(size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, - size_t start_pos, hwy::ThreadPool& pool, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937&, int verbosity); + const AcceptFunc& accept_token, std::mt19937&, + int verbosity) override; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights; hwy::AlignedUniquePtr<Activations<TConfig, kPrefillBatchSize>> prefill; hwy::AlignedUniquePtr<Activations<TConfig, 1>> state; - KVCache kv_cache; }; } // namespace gcpp @@ -503,7 +506,7 @@ void Transformer(int token, size_t pos, template <class TConfig> void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, - const std::vector<int>& prompt, size_t pos, + const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const 
AcceptFunc& accept_token, std::mt19937& gen, @@ -517,7 +520,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens, const CompressedWeights<TConfig>& c_weights = *reinterpret_cast<const CompressedWeights<TConfig>*>( gemma.compressed_weights.get()); - KVCache& kv_cache = gemma.kv_cache; int token; // pos indexes the KV cache. In the first turn of a chat, pos = 0. @@ -612,23 +614,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens, void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, - start_pos, pool, inner_pool, stream_token, accept_token, gen, - verbosity); + start_pos, kv_cache, pool, inner_pool, stream_token, + accept_token, gen, verbosity); } void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, size_t start_pos, - hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, - const StreamFunc& stream_token, const AcceptFunc& accept_token, - std::mt19937& gen, int verbosity) { + KVCache& kv_cache, hwy::ThreadPool& pool, + hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, + const AcceptFunc& accept_token, std::mt19937& gen, + int verbosity) { GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, - start_pos, pool, inner_pool, stream_token, accept_token, gen, - verbosity); + start_pos, kv_cache, pool, inner_pool, stream_token, + accept_token, gen, verbosity); } // Calls func(name, float*, CompressedArray&) for each tensor. 
float* is null @@ -735,13 +739,6 @@ HWY_EXPORT(GetCompressedWeightsT); HWY_EXPORT(Generate2B); HWY_EXPORT(Generate7B); -KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { - KVCache kv_cache = {}; - kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); - kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); - return kv_cache; -} - template <class Config> GemmaImpl<Config>::GemmaImpl( std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer, @@ -753,33 +750,43 @@ GemmaImpl<Config>::GemmaImpl( // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()), state(hwy::MakeUniqueAligned<Activations<Config, 1>>()), - kv_cache( - CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, - Config::kSeqLen)), tokenizer(std::move(tokenizer)) { // PROFILER_ZONE("Startup.tokenizer"); // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); } +template <class Config> +KVCache GemmaImpl<Config>::CreateKVCache() const { + constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads * + Config::kQKVDim; + constexpr const size_t seq_len = Config::kSeqLen; + KVCache kv_cache = {}; + kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); + kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos); + return kv_cache; +} + template <> void GemmaImpl<ConfigGemma2B>::Generate( size_t max_tokens, size_t max_generated_tokens, float temperature, - const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { + const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate2B) (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, - pool, inner_pool, stream_token, accept_token, gen, verbosity); + kv_cache, pool, inner_pool, stream_token, 
accept_token, gen, verbosity); } template <> void GemmaImpl<ConfigGemma7B>::Generate( size_t max_tokens, size_t max_generated_tokens, float temperature, - const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, - const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { + const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const StreamFunc& stream_token, const AcceptFunc& accept_token, + std::mt19937& gen, int verbosity) { HWY_DYNAMIC_DISPATCH(Generate7B) (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, - pool, inner_pool, stream_token, accept_token, gen, verbosity); + kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } // TODO: Make Gemma type independent of LoaderArgs, create a factory function @@ -808,20 +815,24 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { } Gemma::~Gemma() = default; // after GemmaInterface is defined +KVCache Gemma::CreateKVCache() const { + return impl_->CreateKVCache(); +} + const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { return impl_->Tokenizer(); } void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, - size_t start_pos, hwy::ThreadPool& pool, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { pool.SetWaitMode(hwy::PoolWaitMode::kSpin); gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, - start_pos, pool, inner_pool, stream_token, accept_token, - gen, verbosity); + start_pos, kv_cache, pool, inner_pool, stream_token, + accept_token, gen, verbosity); pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } diff --git a/gemma.h b/gemma.h index f6361e1..03da2e5 100644 --- a/gemma.h +++ 
b/gemma.h @@ -157,6 +157,7 @@ struct Gemma { Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + KVCache CreateKVCache() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr<GemmaInterface> impl_; @@ -211,7 +212,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> { void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector<int>& prompt, - size_t start_pos, hwy::ThreadPool& pool, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); diff --git a/run.cc b/run.cc index 4c4f132..ff7ed3d 100644 --- a/run.cc +++ b/run.cc @@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, std::cerr << "\n"; } -void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, const double time_start = hwy::platform::Now(); GenerateGemma(model, args.max_tokens, args.max_generated_tokens, - args.temperature, prompt, abs_pos, pool, inner_pool, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); const double time_end = hwy::platform::Now(); const double tok_sec = current_pos / (time_end - time_start); @@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, 
InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader, pool); + auto kv_cache = model.CreateKVCache(); if (const char* error = inference.Validate()) { ShowHelp(loader, inference, app); @@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } ReplGemma( - model, pool, inner_pool, inference, app.verbosity, + model, kv_cache, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) { return true; }, app.eot_line); }