Separate KV cache from GemmaImpl

This commit is contained in:
RangerUFO 2024-03-05 17:50:24 +08:00
parent 6c0388e049
commit b841612e8c
3 changed files with 59 additions and 46 deletions

View File

@ -231,12 +231,13 @@ struct Activations {
struct GemmaInterface { struct GemmaInterface {
virtual ~GemmaInterface() = default; virtual ~GemmaInterface() = default;
virtual KVCache CreateKVCache() const = 0;
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0; virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
virtual void Generate(size_t max_tokens, size_t max_generated_tokens, virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt, float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0; int verbosity) = 0;
@ -255,22 +256,24 @@ struct GemmaImpl : public GemmaInterface {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>(); c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
} }
const sentencepiece::SentencePieceProcessor* Tokenizer() const { KVCache CreateKVCache() const override;
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
return tokenizer.get(); return tokenizer.get();
} }
void Generate(size_t max_tokens, size_t max_generated_tokens, void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt, float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&, int verbosity); const AcceptFunc& accept_token, std::mt19937&,
int verbosity) override;
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer; std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights; hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill; hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state; hwy::AlignedUniquePtr<Activations<Config, 1>> state;
KVCache kv_cache;
}; };
} // namespace gcpp } // namespace gcpp
@ -503,7 +506,7 @@ void Transformer(int token, size_t pos,
template <class TConfig> template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens, void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t pos, const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
@ -517,7 +520,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
const CompressedWeights<TConfig>& c_weights = const CompressedWeights<TConfig>& c_weights =
*reinterpret_cast<CompressedWeights<TConfig>*>( *reinterpret_cast<CompressedWeights<TConfig>*>(
gemma.compressed_weights.get()); gemma.compressed_weights.get());
KVCache& kv_cache = gemma.kv_cache;
int token; int token;
// pos indexes the KV cache. In the first turn of a chat, pos = 0. // pos indexes the KV cache. In the first turn of a chat, pos = 0.
@ -612,23 +614,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens, void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
std::mt19937& gen, int verbosity) { const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen, start_pos, kv_cache, pool, inner_pool, stream_token,
verbosity); accept_token, gen, verbosity);
} }
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens, void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, KVCache& kv_cache, hwy::ThreadPool& pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
std::mt19937& gen, int verbosity) { const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt, GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen, start_pos, kv_cache, pool, inner_pool, stream_token,
verbosity); accept_token, gen, verbosity);
} }
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@ -735,13 +739,6 @@ HWY_EXPORT(GetCompressedWeightsT);
HWY_EXPORT(Generate2B); HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B); HWY_EXPORT(Generate7B);
// Returns a KVCache whose key and value buffers each hold
// `seq_len * size_cache_pos` floats obtained from hwy::AllocateAligned.
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
  // Both buffers have the same element count, so compute it once.
  const size_t num_floats = seq_len * size_cache_pos;

  KVCache cache = {};
  cache.key_cache = hwy::AllocateAligned<float>(num_floats);
  cache.value_cache = hwy::AllocateAligned<float>(num_floats);
  return cache;
}
template <class Config> template <class Config>
GemmaImpl<Config>::GemmaImpl( GemmaImpl<Config>::GemmaImpl(
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer, std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
@ -753,33 +750,43 @@ GemmaImpl<Config>::GemmaImpl(
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()), prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()), state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
kv_cache(
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen)),
tokenizer(std::move(tokenizer)) { tokenizer(std::move(tokenizer)) {
// PROFILER_ZONE("Startup.tokenizer"); // PROFILER_ZONE("Startup.tokenizer");
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
} }
// Creates a new KVCache sized from the compile-time constants of `Config`.
// The key and value buffers each hold
// kSeqLen * (kLayers * kKVHeads * kQKVDim) floats, allocated with
// hwy::AllocateAligned. The cache is returned by value and is not stored in
// this object.
template <class Config>
KVCache GemmaImpl<Config>::CreateKVCache() const {
  // Per-entry size and sequence length are widened to size_t before being
  // multiplied together.
  constexpr size_t kSizeCachePos =
      Config::kLayers * Config::kKVHeads * Config::kQKVDim;
  constexpr size_t kSeqLen = Config::kSeqLen;
  constexpr size_t kNumFloats = kSeqLen * kSizeCachePos;

  KVCache cache = {};
  cache.key_cache = hwy::AllocateAligned<float>(kNumFloats);
  cache.value_cache = hwy::AllocateAligned<float>(kNumFloats);
  return cache;
}
template <> template <>
void GemmaImpl<ConfigGemma2B>::Generate( void GemmaImpl<ConfigGemma2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature, size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool, const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate2B) HWY_DYNAMIC_DISPATCH(Generate2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity); kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
} }
template <> template <>
void GemmaImpl<ConfigGemma7B>::Generate( void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature, size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool, const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) { const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate7B) HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos, (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity); kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
} }
// TODO: Make Gemma type independent of LoaderArgs, create a factory function // TODO: Make Gemma type independent of LoaderArgs, create a factory function
@ -808,20 +815,24 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
} }
Gemma::~Gemma() = default; // after GemmaInterface is defined Gemma::~Gemma() = default; // after GemmaInterface is defined
// Returns a new KVCache by forwarding to the underlying implementation's
// CreateKVCache().
KVCache Gemma::CreateKVCache() const {
  const GemmaInterface& impl = *impl_;
  return impl.CreateKVCache();
}
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
return impl_->Tokenizer(); return impl_->Tokenizer();
} }
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt, float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) { int verbosity) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin); pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt, gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, start_pos, kv_cache, pool, inner_pool, stream_token,
gen, verbosity); accept_token, gen, verbosity);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock); pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
} }

View File

@ -157,6 +157,7 @@ struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined. ~Gemma(); // must be defined after GemmaInterface's dtor is defined.
KVCache CreateKVCache() const;
const sentencepiece::SentencePieceProcessor* Tokenizer() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_; std::unique_ptr<GemmaInterface> impl_;
@ -211,7 +212,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt, float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity); int verbosity);

13
run.cc
View File

@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
std::cerr << "\n"; std::cerr << "\n";
} }
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& inner_pool, const InferenceArgs& args, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
int verbosity, const gcpp::AcceptFunc& accept_token, const InferenceArgs& args, int verbosity,
std::string& eot_line) { const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn int current_pos = 0; // token index within the current turn
@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
const double time_start = hwy::platform::Now(); const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens, GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, pool, inner_pool, args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
stream_token, accept_token, gen, verbosity); stream_token, accept_token, gen, verbosity);
const double time_end = hwy::platform::Now(); const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start); const double tok_sec = current_pos / (time_end - time_start);
@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
} }
gcpp::Gemma model(loader, pool); gcpp::Gemma model(loader, pool);
auto kv_cache = model.CreateKVCache();
if (const char* error = inference.Validate()) { if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app); ShowHelp(loader, inference, app);
@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
} }
ReplGemma( ReplGemma(
model, pool, inner_pool, inference, app.verbosity, model, kv_cache, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line); /*accept_token=*/[](int) { return true; }, app.eot_line);
} }