Merge pull request #81 from ufownl/feature/separated_kvcache

Separate KV cache from GemmaImpl
This commit is contained in:
Austin Huang 2024-03-07 10:10:11 -05:00 committed by GitHub
commit 3df06f64c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 39 deletions

View File

@ -235,13 +235,30 @@ struct GemmaInterface {
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool,
size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) = 0;
};
// Builds a KVCache sized for a specific model configuration: one cache slot
// per (layer, KV head, QKV dimension) position, covering kSeqLen tokens.
template <class Config>
KVCache CreateKVCache() {
  // Total per-position cache size derived from the compile-time config.
  constexpr auto kSizeCachePos =
      Config::kLayers * Config::kKVHeads * Config::kQKVDim;
  return CreateKVCache(kSizeCachePos, Config::kSeqLen);
}
// Runtime dispatch from a Model enum to the config-templated CreateKVCache.
// Aborts on an unrecognized model type.
KVCache CreateKVCache(Model type) {
  if (type == Model::GEMMA_2B) {
    return CreateKVCache<ConfigGemma2B>();
  }
  if (type == Model::GEMMA_7B) {
    return CreateKVCache<ConfigGemma7B>();
  }
  HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
template <class Config>
struct GemmaImpl : public GemmaInterface {
GemmaImpl( // const LoaderArgs& args,
@ -255,22 +272,22 @@ struct GemmaImpl : public GemmaInterface {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}
const sentencepiece::SentencePieceProcessor* Tokenizer() const {
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
return tokenizer.get();
}
void Generate(size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937&, int verbosity);
const AcceptFunc& accept_token, std::mt19937&,
int verbosity) override;
std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
hwy::AlignedUniquePtr<Activations<Config, 1>> state;
KVCache kv_cache;
};
} // namespace gcpp
@ -503,7 +520,7 @@ void Transformer(int token, size_t pos,
template <class TConfig>
void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t pos,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
@ -517,7 +534,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
const CompressedWeights<TConfig>& c_weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(
gemma.compressed_weights.get());
KVCache& kv_cache = gemma.kv_cache;
int token;
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
@ -612,23 +628,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen,
verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token, gen,
verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@ -753,9 +771,6 @@ GemmaImpl<Config>::GemmaImpl(
// HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
kv_cache(
CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen)),
tokenizer(std::move(tokenizer)) {
// PROFILER_ZONE("Startup.tokenizer");
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
@ -764,22 +779,24 @@ GemmaImpl<Config>::GemmaImpl(
template <>
void GemmaImpl<ConfigGemma2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate2B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity);
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
template <>
void GemmaImpl<ConfigGemma7B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const StreamFunc& stream_token, const AcceptFunc& accept_token,
std::mt19937& gen, int verbosity) {
HWY_DYNAMIC_DISPATCH(Generate7B)
(*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
pool, inner_pool, stream_token, accept_token, gen, verbosity);
kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
}
// TODO: Make Gemma type independent of LoaderArgs, create a factory function
@ -814,14 +831,14 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity) {
pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
start_pos, pool, inner_pool, stream_token, accept_token,
gen, verbosity);
start_pos, kv_cache, pool, inner_pool, stream_token,
accept_token, gen, verbosity);
pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
}

View File

@ -163,6 +163,9 @@ struct Gemma {
gcpp::ModelTraining model_training;
};
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;
@ -211,7 +214,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
float temperature, const std::vector<int>& prompt,
size_t start_pos, hwy::ThreadPool& pool,
size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
const AcceptFunc& accept_token, std::mt19937& gen,
int verbosity);

13
run.cc
View File

@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
std::cerr << "\n";
}
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token,
std::string& eot_line) {
void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
const InferenceArgs& args, int verbosity,
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
const double time_start = hwy::platform::Now();
GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
args.temperature, prompt, abs_pos, pool, inner_pool,
args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool,
stream_token, accept_token, gen, verbosity);
const double time_end = hwy::platform::Now();
const double tok_sec = current_pos / (time_end - time_start);
@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
gcpp::Gemma model(loader, pool);
auto kv_cache = CreateKVCache(loader.ModelType());
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);
@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
ReplGemma(
model, pool, inner_pool, inference, app.verbosity,
model, kv_cache, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }, app.eot_line);
}