mirror of https://github.com/google/gemma.cpp.git
Merge pull request #81 from ufownl/feature/separated_kvcache
Separate KV cache from GemmaImpl
commit 3df06f64c2

gemma.cc (81 changed lines)

@@ -235,13 +235,30 @@ struct GemmaInterface {
 
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, hwy::ThreadPool& pool,
-                        hwy::ThreadPool& inner_pool,
+                        size_t start_pos, KVCache& kv_cache,
+                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
 
+template <class Config>
+KVCache CreateKVCache() {
+  return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
+                       Config::kSeqLen);
+}
+
+KVCache CreateKVCache(Model type) {
+  switch (type) {
+    case Model::GEMMA_2B:
+      return CreateKVCache<ConfigGemma2B>();
+    case Model::GEMMA_7B:
+      return CreateKVCache<ConfigGemma7B>();
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
+  }
+}
+
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
   GemmaImpl(  // const LoaderArgs& args,
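
The templated CreateKVCache() added above derives the flat per-position cache size from three compile-time config constants, and the Model-keyed overload picks the right instantiation at runtime. A toy illustration of that sizing follows; ConfigToy and its constants are invented for the example, not the real ConfigGemma2B/ConfigGemma7B values.

// Illustration only: how CreateKVCache<Config>() sizes the cache.
// ConfigToy is hypothetical; kLayers/kKVHeads/kQKVDim/kSeqLen are made up.
struct ConfigToy {
  static constexpr size_t kLayers = 4;
  static constexpr size_t kKVHeads = 1;
  static constexpr size_t kQKVDim = 64;
  static constexpr size_t kSeqLen = 1024;
};
// CreateKVCache<ConfigToy>() would forward to the size-based overload as
// CreateKVCache(/*size_cache_pos=*/4 * 1 * 64, /*seq_len=*/1024).
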
@@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
   }
 
-  const sentencepiece::SentencePieceProcessor* Tokenizer() const {
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
     return tokenizer.get();
   }
 
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
-                size_t start_pos, hwy::ThreadPool& pool,
+                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity);
+                const AcceptFunc& accept_token, std::mt19937&,
+                int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
 
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
-  KVCache kv_cache;
 };
 
 }  // namespace gcpp
@@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
 template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
-                  const std::vector<int>& prompt, size_t pos,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
@@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   const CompressedWeights<TConfig>& c_weights =
       *reinterpret_cast<CompressedWeights<TConfig>*>(
           gemma.compressed_weights.get());
-  KVCache& kv_cache = gemma.kv_cache;
   int token;
 
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(
       //     HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
-      kv_cache(
-          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)),
       tokenizer(std::move(tokenizer)) {
   // PROFILER_ZONE("Startup.tokenizer");
   // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
@@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(
 template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 // TODO: Make Gemma type independent of LoaderArgs, create a factory function
@@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
 
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, hwy::ThreadPool& pool,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, pool, inner_pool, stream_token, accept_token,
-                        gen, verbosity);
+                        start_pos, kv_cache, pool, inner_pool, stream_token,
+                        accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
gemma.h (5 changed lines)

@@ -163,6 +163,9 @@ struct Gemma {
   gcpp::ModelTraining model_training;
 };
 
+KVCache CreateKVCache(Model type);  // convenient workaround for now
+KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
+
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f.
 using StreamFunc = std::function<bool(int, float)>;
@@ -211,7 +214,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, hwy::ThreadPool& pool,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity);
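
With the cache moved out of GemmaImpl, callers of the public API create it themselves and pass it into GenerateGemma. Below is a minimal caller sketch against the declarations above, not code from this PR: it assumes it sits inside namespace gcpp (as run.cc does) and that gemma.h, <random> and <vector> are included; the pools, prompt, seed, token limits and callbacks are illustrative placeholders.

// Sketch only: caller-owned KVCache per the updated gemma.h declarations.
// All numeric values and the no-op callbacks are placeholders.
void GenerateOnce(Gemma& model, Model model_type, const std::vector<int>& prompt,
                  hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool) {
  KVCache kv_cache = CreateKVCache(model_type);  // caller owns the cache now
  std::mt19937 gen(/*seed=*/42);
  GenerateGemma(model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
                /*temperature=*/1.0f, prompt, /*start_pos=*/0, kv_cache, pool,
                inner_pool,
                /*stream_token=*/[](int /*token*/, float /*prob*/) { return true; },
                /*accept_token=*/[](int /*token*/) { return true; }, gen,
                /*verbosity=*/0);
}

Reusing the same kv_cache with an increasing start_pos continues a conversation across turns (as run.cc does with abs_pos below); constructing a fresh cache starts from an empty attention state.
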
run.cc (13 changed lines)

@@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
   std::cerr << "\n";
 }
 
-void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
-               hwy::ThreadPool& inner_pool, const InferenceArgs& args,
-               int verbosity, const gcpp::AcceptFunc& accept_token,
-               std::string& eot_line) {
+void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
+               hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+               const InferenceArgs& args, int verbosity,
+               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   int abs_pos = 0;      // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
 
       const double time_start = hwy::platform::Now();
       GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
-                    args.temperature, prompt, abs_pos, pool, inner_pool,
+                    args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
                     stream_token, accept_token, gen, verbosity);
       const double time_end = hwy::platform::Now();
       const double tok_sec = current_pos / (time_end - time_start);
@@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   gcpp::Gemma model(loader, pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
 
   if (const char* error = inference.Validate()) {
     ShowHelp(loader, inference, app);
@@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   }
 
   ReplGemma(
-      model, pool, inner_pool, inference, app.verbosity,
+      model, kv_cache, pool, inner_pool, inference, app.verbosity,
       /*accept_token=*/[](int) { return true; }, app.eot_line);
 }
 
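
The interactive loop above still uses a single cache, created once from the model type. A consequence of the separation, sketched here rather than shown in the PR, is that one loaded model can serve several independent sequences by giving each its own cache. RunTurn, prompt_a and prompt_b are hypothetical placeholders (RunTurn standing in for a GenerateGemma call like the sketch after the gemma.h section), and loader, pool and inner_pool are assumed to be the objects set up in Run().

// Sketch only: one model, two independent KV caches (e.g. two chat sessions),
// used sequentially. RunTurn and the prompts are hypothetical placeholders.
gcpp::Gemma model(loader, pool);
gcpp::KVCache chat_a = gcpp::CreateKVCache(loader.ModelType());
gcpp::KVCache chat_b = gcpp::CreateKVCache(loader.ModelType());
RunTurn(model, chat_a, prompt_a, pool, inner_pool);  // touches only chat_a's state
RunTurn(model, chat_b, prompt_b, pool, inner_pool);  // touches only chat_b's state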