Make `CreateKVCache` a free function rather than a method

This commit is contained in:
RangerUFO 2024-03-07 14:08:48 +08:00
parent b841612e8c
commit 170a9b4690
3 changed files with 28 additions and 20 deletions

View File

@ -231,7 +231,6 @@ struct Activations {
struct GemmaInterface {
virtual ~GemmaInterface() = default;
virtual KVCache CreateKVCache() const = 0;
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
@ -243,6 +242,23 @@ struct GemmaInterface {
int verbosity) = 0;
};
// Builds a KVCache sized for a specific model configuration: derives the
// per-position cache size from the Config's compile-time constants
// (layers * KV heads * QKV dim) and forwards to the runtime-sized
// CreateKVCache(size_t, size_t) overload.
template <class Config>
KVCache CreateKVCache() {
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
Config::kSeqLen);
}
// Runtime dispatch from a Model enum to the config-templated factory.
// Supports the 2B and 7B Gemma configurations; any other enum value is a
// programming error and aborts the process.
KVCache CreateKVCache(Model type) {
switch (type) {
case Model::GEMMA_2B:
return CreateKVCache<ConfigGemma2B>();
case Model::GEMMA_7B:
return CreateKVCache<ConfigGemma7B>();
default:
// Unreachable for valid models; abort rather than return a bogus cache.
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
}
template <class Config>
struct GemmaImpl : public GemmaInterface {
GemmaImpl( // const LoaderArgs& args,
@ -256,8 +272,6 @@ struct GemmaImpl : public GemmaInterface {
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
}
KVCache CreateKVCache() const override;
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
return tokenizer.get();
}
@ -739,6 +753,13 @@ HWY_EXPORT(GetCompressedWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
// Allocates the key/value caches for one inference session.
// size_cache_pos: floats stored per sequence position (layers * kv_heads * qkv_dim).
// seq_len: maximum sequence length the cache must cover.
// Each buffer holds seq_len * size_cache_pos floats, aligned for SIMD use.
// NOTE(review): hwy::AllocateAligned presumably leaves the buffers
// uninitialized — confirm callers write before reading.
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
return kv_cache;
}
template <class Config>
GemmaImpl<Config>::GemmaImpl(
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
@ -755,17 +776,6 @@ GemmaImpl<Config>::GemmaImpl(
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
}
// (Deleted by this commit.) The former per-instance method: computed the
// cache dimensions from Config at compile time and allocated the aligned
// key/value buffers inline. Its logic now lives in the free functions
// CreateKVCache<Config>() and CreateKVCache(size_t, size_t).
template <class Config>
KVCache GemmaImpl<Config>::CreateKVCache() const {
constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
Config::kQKVDim;
constexpr const size_t seq_len = Config::kSeqLen;
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
return kv_cache;
}
template <>
void GemmaImpl<ConfigGemma2B>::Generate(
size_t max_tokens, size_t max_generated_tokens, float temperature,
@ -815,10 +825,6 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
}
Gemma::~Gemma() = default; // after GemmaInterface is defined
// (Deleted by this commit.) Thin forwarder to the virtual
// GemmaInterface::CreateKVCache; obsolete now that callers use the free
// CreateKVCache(Model) function instead.
KVCache Gemma::CreateKVCache() const {
return impl_->CreateKVCache();
}
// Returns the underlying SentencePiece tokenizer (non-owning pointer);
// forwards to the pimpl'd GemmaInterface implementation.
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
return impl_->Tokenizer();
}

View File

@ -157,13 +157,15 @@ struct Gemma {
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
KVCache CreateKVCache() const;
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
gcpp::ModelTraining model_training;
};
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f.
using StreamFunc = std::function<bool(int, float)>;

run.cc (2 changes)
View File

@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
gcpp::Gemma model(loader, pool);
auto kv_cache = model.CreateKVCache();
auto kv_cache = CreateKVCache(loader.ModelType());
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);