diff --git a/gemma.cc b/gemma.cc
index bbd86c3..ba6aafa 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -235,13 +235,30 @@ struct GemmaInterface {
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, hwy::ThreadPool& pool,
-                        hwy::ThreadPool& inner_pool,
+                        size_t start_pos, KVCache& kv_cache,
+                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
 
+template <class Config>
+KVCache CreateKVCache() {
+  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                       Config::kSeqLen);
+}
+
+KVCache CreateKVCache(Model type) {
+  switch (type) {
+    case Model::GEMMA_2B:
+      return CreateKVCache<ConfigGemma2B>();
+    case Model::GEMMA_7B:
+      return CreateKVCache<ConfigGemma7B>();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
+  }
+}
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
   GemmaImpl(  // const LoaderArgs& args,
@@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
   }
 
-  const sentencepiece::SentencePieceProcessor* Tokenizer() const {
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
     return tokenizer.get();
   }
 
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
-                size_t start_pos, hwy::ThreadPool& pool,
+                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity);
+                const AcceptFunc& accept_token, std::mt19937&,
+                int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
-  KVCache kv_cache;
 };
 
 }  // namespace gcpp
@@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
 template <class Config>
 void GenerateImpl(GemmaImpl<Config>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
-                  const std::vector<int>& prompt, size_t pos,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
@@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<Config>& gemma, size_t max_tokens,
   const CompressedWeights<Config>& c_weights =
       *reinterpret_cast<CompressedWeights<Config>*>(
           gemma.compressed_weights.get());
-  KVCache& kv_cache = gemma.kv_cache;
   int token;
 
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<Config>& gemma, size_t max_tokens,
 void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(  //
       // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
-      kv_cache(
-          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)),
       tokenizer(std::move(tokenizer)) {
   // PROFILER_ZONE("Startup.tokenizer");
   // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
@@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(  //
 template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 // TODO: Make Gemma type independent of LoaderArgs, create a factory function
@@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
 void GenerateGemma(Gemma& gemma, size_t max_tokens,
                    size_t max_generated_tokens, float temperature,
                    const std::vector<int>& prompt,
-                   size_t start_pos, hwy::ThreadPool& pool,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, pool, inner_pool, stream_token, accept_token,
-                        gen, verbosity);
+                        start_pos, kv_cache, pool, inner_pool, stream_token,
+                        accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
diff --git a/gemma.h b/gemma.h
index f6361e1..58fd74a 100644
--- a/gemma.h
+++ b/gemma.h
@@ -163,6 +163,9 @@ struct Gemma {
   gcpp::ModelTraining model_training;
 };
 
+KVCache CreateKVCache(Model type);  // convenient workaround for now
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f.
 using StreamFunc = std::function<bool(int, float)>;
@@ -211,7 +214,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 void GenerateGemma(Gemma& gemma, size_t max_tokens,
                    size_t max_generated_tokens, float temperature,
                    const std::vector<int>& prompt,
-                   size_t start_pos, hwy::ThreadPool& pool,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity);
 
diff --git a/run.cc b/run.cc
index 4c4f132..40be63e 100644
--- a/run.cc
+++ b/run.cc
@@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
   std::cerr << "\n";
 }
 
-void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
+               hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   int abs_pos = 0;      // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
 
   const double time_start = hwy::platform::Now();
   GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                args.temperature, prompt, abs_pos, pool, inner_pool,
+                args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
                 stream_token, accept_token, gen, verbosity);
   const double time_end = hwy::platform::Now();
   const double tok_sec = current_pos / (time_end - time_start);
@@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   gcpp::Gemma model(loader, pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
 
   if (const char* error = inference.Validate()) {
     ShowHelp(loader, inference, app);
@@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, kv_cache, pool, inner_pool, inference, app.verbosity,
+      model, kv_cache, pool, inner_pool, inference, app.verbosity,
      /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
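
Note: this change moves ownership of the KV cache out of GemmaImpl and into the caller, so a single loaded model can serve several independent conversations, each with its own cache. A minimal sketch of the new calling convention follows, assuming the usual run.cc setup has already been done (loader, pool, inner_pool, prompt, stream_token, and gen are placeholders from that context, and the argument values are illustrative, not part of this diff):

    // Load the model once; the cache is no longer created inside it.
    gcpp::Gemma model(loader, pool);

    // The caller now creates the cache, sized for the model type...
    gcpp::KVCache kv_cache = gcpp::CreateKVCache(loader.ModelType());

    // ...and passes it explicitly into each generation call, after start_pos.
    gcpp::GenerateGemma(model, /*max_tokens=*/2048,
                        /*max_generated_tokens=*/1024, /*temperature=*/1.0f,
                        prompt, /*start_pos=*/0, kv_cache, pool, inner_pool,
                        stream_token, /*accept_token=*/[](int) { return true; },
                        gen, /*verbosity=*/1);

To continue one conversation, reuse its kv_cache with a nonzero start_pos; to start a fresh conversation against the same model, create a second cache and pass that instead.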