mirror of https://github.com/google/gemma.cpp.git
Separate KV cache from GemmaImpl
This commit is contained in:
parent
6c0388e049
commit
b841612e8c
89
gemma.cc
89
gemma.cc
|
|
@ -231,12 +231,13 @@ struct Activations {
|
|||
struct GemmaInterface {
|
||||
virtual ~GemmaInterface() = default;
|
||||
|
||||
virtual KVCache CreateKVCache() const = 0;
|
||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool,
|
||||
size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) = 0;
|
||||
|
|
@ -255,22 +256,24 @@ struct GemmaImpl : public GemmaInterface {
|
|||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const {
|
||||
KVCache CreateKVCache() const override;
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||
return tokenizer.get();
|
||||
}
|
||||
|
||||
void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
|
||||
const AcceptFunc& accept_token, std::mt19937&,
|
||||
int verbosity) override;
|
||||
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
|
||||
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
|
||||
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
|
||||
KVCache kv_cache;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -503,7 +506,7 @@ void Transformer(int token, size_t pos,
|
|||
template <class TConfig>
|
||||
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t pos,
|
||||
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
|
|
@ -517,7 +520,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
const CompressedWeights<TConfig>& c_weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(
|
||||
gemma.compressed_weights.get());
|
||||
KVCache& kv_cache = gemma.kv_cache;
|
||||
int token;
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
|
|
@ -612,23 +614,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
||||
verbosity);
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
|
||||
size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, pool, inner_pool, stream_token, accept_token, gen,
|
||||
verbosity);
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
|
|
@ -735,13 +739,6 @@ HWY_EXPORT(GetCompressedWeightsT);
|
|||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||
KVCache kv_cache = {};
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
return kv_cache;
|
||||
}
|
||||
|
||||
template <class Config>
|
||||
GemmaImpl<Config>::GemmaImpl(
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||
|
|
@ -753,33 +750,43 @@ GemmaImpl<Config>::GemmaImpl(
|
|||
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
|
||||
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
|
||||
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
|
||||
kv_cache(
|
||||
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||
Config::kSeqLen)),
|
||||
tokenizer(std::move(tokenizer)) {
|
||||
// PROFILER_ZONE("Startup.tokenizer");
|
||||
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||
}
|
||||
|
||||
template <class Config>
|
||||
KVCache GemmaImpl<Config>::CreateKVCache() const {
|
||||
constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
|
||||
Config::kQKVDim;
|
||||
constexpr const size_t seq_len = Config::kSeqLen;
|
||||
KVCache kv_cache = {};
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
return kv_cache;
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate2B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma7B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
|
||||
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const StreamFunc& stream_token, const AcceptFunc& accept_token,
|
||||
std::mt19937& gen, int verbosity) {
|
||||
HWY_DYNAMIC_DISPATCH(Generate7B)
|
||||
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
|
||||
pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
|
||||
}
|
||||
|
||||
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
|
||||
|
|
@ -808,20 +815,24 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
|||
}
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
KVCache Gemma::CreateKVCache() const {
|
||||
return impl_->CreateKVCache();
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||
return impl_->Tokenizer();
|
||||
}
|
||||
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity) {
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
|
||||
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
|
||||
start_pos, pool, inner_pool, stream_token, accept_token,
|
||||
gen, verbosity);
|
||||
start_pos, kv_cache, pool, inner_pool, stream_token,
|
||||
accept_token, gen, verbosity);
|
||||
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
|
||||
}
|
||||
|
||||
|
|
|
|||
3
gemma.h
3
gemma.h
|
|
@ -157,6 +157,7 @@ struct Gemma {
|
|||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
KVCache CreateKVCache() const;
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
|
|
@ -211,7 +212,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
|
||||
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
|
||||
float temperature, const std::vector<int>& prompt,
|
||||
size_t start_pos, hwy::ThreadPool& pool,
|
||||
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
|
||||
const AcceptFunc& accept_token, std::mt19937& gen,
|
||||
int verbosity);
|
||||
|
|
|
|||
13
run.cc
13
run.cc
|
|
@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
|
|||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
|
||||
const InferenceArgs& args, int verbosity,
|
||||
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
|
|
@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
|
||||
const double time_start = hwy::platform::Now();
|
||||
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
|
||||
args.temperature, prompt, abs_pos, pool, inner_pool,
|
||||
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
|
||||
stream_token, accept_token, gen, verbosity);
|
||||
const double time_end = hwy::platform::Now();
|
||||
const double tok_sec = current_pos / (time_end - time_start);
|
||||
|
|
@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
gcpp::Gemma model(loader, pool);
|
||||
auto kv_cache = model.CreateKVCache();
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
|
|
@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
ReplGemma(
|
||||
model, pool, inner_pool, inference, app.verbosity,
|
||||
model, kv_cache, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue