mirror of https://github.com/google/gemma.cpp.git
Make `CreateKVCache` a free function rather than a method
This commit is contained in:
parent
b841612e8c
commit
170a9b4690
42
gemma.cc
42
gemma.cc
|
|
@ -231,7 +231,6 @@ struct Activations {
|
|||
struct GemmaInterface {
|
||||
virtual ~GemmaInterface() = default;
|
||||
|
||||
virtual KVCache CreateKVCache() const = 0;
|
||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||
|
||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||
|
|
@ -243,6 +242,23 @@ struct GemmaInterface {
|
|||
int verbosity) = 0;
|
||||
};
|
||||
|
||||
template <class Config>
|
||||
KVCache CreateKVCache() {
|
||||
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||
Config::kSeqLen);
|
||||
}
|
||||
|
||||
KVCache CreateKVCache(Model type) {
|
||||
switch (type) {
|
||||
case Model::GEMMA_2B:
|
||||
return CreateKVCache<ConfigGemma2B>();
|
||||
case Model::GEMMA_7B:
|
||||
return CreateKVCache<ConfigGemma7B>();
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
|
||||
}
|
||||
}
|
||||
|
||||
template <class Config>
|
||||
struct GemmaImpl : public GemmaInterface {
|
||||
GemmaImpl( // const LoaderArgs& args,
|
||||
|
|
@ -256,8 +272,6 @@ struct GemmaImpl : public GemmaInterface {
|
|||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||
}
|
||||
|
||||
KVCache CreateKVCache() const override;
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||
return tokenizer.get();
|
||||
}
|
||||
|
|
@ -739,6 +753,13 @@ HWY_EXPORT(GetCompressedWeightsT);
|
|||
HWY_EXPORT(Generate2B);
|
||||
HWY_EXPORT(Generate7B);
|
||||
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
|
||||
KVCache kv_cache = {};
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
return kv_cache;
|
||||
}
|
||||
|
||||
template <class Config>
|
||||
GemmaImpl<Config>::GemmaImpl(
|
||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||
|
|
@ -755,17 +776,6 @@ GemmaImpl<Config>::GemmaImpl(
|
|||
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||
}
|
||||
|
||||
template <class Config>
|
||||
KVCache GemmaImpl<Config>::CreateKVCache() const {
|
||||
constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
|
||||
Config::kQKVDim;
|
||||
constexpr const size_t seq_len = Config::kSeqLen;
|
||||
KVCache kv_cache = {};
|
||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
return kv_cache;
|
||||
}
|
||||
|
||||
template <>
|
||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||
|
|
@ -815,10 +825,6 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
|||
}
|
||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||
|
||||
KVCache Gemma::CreateKVCache() const {
|
||||
return impl_->CreateKVCache();
|
||||
}
|
||||
|
||||
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
|
||||
return impl_->Tokenizer();
|
||||
}
|
||||
|
|
|
|||
4
gemma.h
4
gemma.h
|
|
@ -157,13 +157,15 @@ struct Gemma {
|
|||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||
|
||||
KVCache CreateKVCache() const;
|
||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
gcpp::ModelTraining model_training;
|
||||
};
|
||||
|
||||
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
|
|
|
|||
2
run.cc
2
run.cc
|
|
@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
gcpp::Gemma model(loader, pool);
|
||||
auto kv_cache = model.CreateKVCache();
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
|
|
|
|||
Loading…
Reference in New Issue