mirror of https://github.com/google/gemma.cpp.git
Separate KV cache from GemmaImpl
This commit is contained in:
parent
6c0388e049
commit
b841612e8c
89
gemma.cc
89
gemma.cc
|
|
@ -231,12 +231,13 @@ struct Activations {
|
||||||
struct GemmaInterface {
|
struct GemmaInterface {
|
||||||
virtual ~GemmaInterface() = default;
|
virtual ~GemmaInterface() = default;
|
||||||
|
|
||||||
|
virtual KVCache CreateKVCache() const = 0;
|
||||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||||
|
|
||||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
size_t start_pos, hwy::ThreadPool& pool,
|
size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc& stream_token,
|
const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity) = 0;
|
int verbosity) = 0;
|
||||||
|
|
@ -255,22 +256,24 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||||
}
|
}
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const {
|
KVCache CreateKVCache() const override;
|
||||||
|
|
||||||
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||||
return tokenizer.get();
|
return tokenizer.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
size_t start_pos, hwy::ThreadPool& pool,
|
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
const AcceptFunc& accept_token, std::mt19937&,
|
||||||
|
int verbosity) override;
|
||||||
|
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||||
KVCache kv_cache;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -503,7 +506,7 @@ void Transformer(int token, size_t pos,
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
size_t max_generated_tokens, float temperature,
|
size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t pos,
|
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const StreamFunc& stream_token,
|
const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
|
|
@ -517,7 +520,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
const CompressedWeights<TConfig>& c_weights =
|
const CompressedWeights<TConfig>& c_weights =
|
||||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
||||||
gemma.compressed_weights.get());
|
gemma.compressed_weights.get());
|
||||||
KVCache& kv_cache = gemma.kv_cache;
|
|
||||||
int token;
|
int token;
|
||||||
|
|
||||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||||
|
|
@ -612,23 +614,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||||
size_t max_generated_tokens, float temperature,
|
size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
std::mt19937& gen, int verbosity) {
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
|
int verbosity) {
|
||||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||||
verbosity);
|
accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||||
size_t max_generated_tokens, float temperature,
|
size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t start_pos,
|
const std::vector<int>& prompt, size_t start_pos,
|
||||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
std::mt19937& gen, int verbosity) {
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
|
int verbosity) {
|
||||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||||
verbosity);
|
accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||||
|
|
@ -735,13 +739,6 @@ HWY_EXPORT(GetCompressedWeightsT);
|
||||||
HWY_EXPORT(Generate2B);
|
HWY_EXPORT(Generate2B);
|
||||||
HWY_EXPORT(Generate7B);
|
HWY_EXPORT(Generate7B);
|
||||||
|
|
||||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
|
||||||
KVCache kv_cache = {};
|
|
||||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
return kv_cache;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
GemmaImpl<Config>::GemmaImpl(
|
GemmaImpl<Config>::GemmaImpl(
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
|
|
@ -753,33 +750,43 @@ GemmaImpl<Config>::GemmaImpl(
|
||||||
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
||||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
||||||
kv_cache(
|
|
||||||
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
|
||||||
Config::kSeqLen)),
|
|
||||||
tokenizer(std::move(tokenizer)) {
|
tokenizer(std::move(tokenizer)) {
|
||||||
// PROFILER_ZONE("Startup.tokenizer");
|
// PROFILER_ZONE("Startup.tokenizer");
|
||||||
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class Config>
|
||||||
|
KVCache GemmaImpl<Config>::CreateKVCache() const {
|
||||||
|
constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
|
||||||
|
Config::kQKVDim;
|
||||||
|
constexpr const size_t seq_len = Config::kSeqLen;
|
||||||
|
KVCache kv_cache = {};
|
||||||
|
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||||
|
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||||
|
return kv_cache;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||||
|
std::mt19937& gen, int verbosity) {
|
||||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||||
|
std::mt19937& gen, int verbosity) {
|
||||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
|
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
|
||||||
|
|
@ -808,20 +815,24 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||||
}
|
}
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
|
KVCache Gemma::CreateKVCache() const {
|
||||||
|
return impl_->CreateKVCache();
|
||||||
|
}
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||||
return impl_->Tokenizer();
|
return impl_->Tokenizer();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
size_t start_pos, hwy::ThreadPool& pool,
|
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity) {
|
int verbosity) {
|
||||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||||
start_pos, pool, inner_pool, stream_token, accept_token,
|
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||||
gen, verbosity);
|
accept_token, gen, verbosity);
|
||||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
3
gemma.h
3
gemma.h
|
|
@ -157,6 +157,7 @@ struct Gemma {
|
||||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
|
|
||||||
|
KVCache CreateKVCache() const;
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
|
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
|
|
@ -211,7 +212,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
||||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||||
float temperature, const std::vector<int>& prompt,
|
float temperature, const std::vector<int>& prompt,
|
||||||
size_t start_pos, hwy::ThreadPool& pool,
|
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||||
int verbosity);
|
int verbosity);
|
||||||
|
|
|
||||||
13
run.cc
13
run.cc
|
|
@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
||||||
std::cerr << "\n";
|
std::cerr << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
const InferenceArgs& args, int verbosity,
|
||||||
std::string& eot_line) {
|
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
int abs_pos = 0; // absolute token index over all turns
|
int abs_pos = 0; // absolute token index over all turns
|
||||||
int current_pos = 0; // token index within the current turn
|
int current_pos = 0; // token index within the current turn
|
||||||
|
|
@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
|
|
||||||
const double time_start = hwy::platform::Now();
|
const double time_start = hwy::platform::Now();
|
||||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||||
args.temperature, prompt, abs_pos, pool, inner_pool,
|
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
|
||||||
stream_token, accept_token, gen, verbosity);
|
stream_token, accept_token, gen, verbosity);
|
||||||
const double time_end = hwy::platform::Now();
|
const double time_end = hwy::platform::Now();
|
||||||
const double tok_sec = current_pos / (time_end - time_start);
|
const double tok_sec = current_pos / (time_end - time_start);
|
||||||
|
|
@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader, pool);
|
||||||
|
auto kv_cache = model.CreateKVCache();
|
||||||
|
|
||||||
if (const char* error = inference.Validate()) {
|
if (const char* error = inference.Validate()) {
|
||||||
ShowHelp(loader, inference, app);
|
ShowHelp(loader, inference, app);
|
||||||
|
|
@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplGemma(
|
ReplGemma(
|
||||||
model, pool, inner_pool, inference, app.verbosity,
|
model, kv_cache, pool, inner_pool, inference, app.verbosity,
|
||||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue