mirror of https://github.com/google/gemma.cpp.git
Make `CreateKVCache` a free function rather than a method
This commit is contained in:
parent
b841612e8c
commit
170a9b4690
42
gemma.cc
42
gemma.cc
|
|
@ -231,7 +231,6 @@ struct Activations {
|
||||||
struct GemmaInterface {
|
struct GemmaInterface {
|
||||||
virtual ~GemmaInterface() = default;
|
virtual ~GemmaInterface() = default;
|
||||||
|
|
||||||
virtual KVCache CreateKVCache() const = 0;
|
|
||||||
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
|
||||||
|
|
||||||
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
|
||||||
|
|
@ -243,6 +242,23 @@ struct GemmaInterface {
|
||||||
int verbosity) = 0;
|
int verbosity) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class Config>
|
||||||
|
KVCache CreateKVCache() {
|
||||||
|
return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
|
||||||
|
Config::kSeqLen);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runtime dispatch from a `Model` enumerator to the matching compile-time
// configuration. Aborts (via HWY_ABORT) for any value that is neither
// GEMMA_2B nor GEMMA_7B.
KVCache CreateKVCache(Model type) {
  if (type == Model::GEMMA_2B) {
    return CreateKVCache<ConfigGemma2B>();
  }
  if (type == Model::GEMMA_7B) {
    return CreateKVCache<ConfigGemma7B>();
  }
  HWY_ABORT("Model type %d unknown.", static_cast<int>(type));
}
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
struct GemmaImpl : public GemmaInterface {
|
struct GemmaImpl : public GemmaInterface {
|
||||||
GemmaImpl( // const LoaderArgs& args,
|
GemmaImpl( // const LoaderArgs& args,
|
||||||
|
|
@ -256,8 +272,6 @@ struct GemmaImpl : public GemmaInterface {
|
||||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
|
||||||
}
|
}
|
||||||
|
|
||||||
KVCache CreateKVCache() const override;
|
|
||||||
|
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
|
||||||
return tokenizer.get();
|
return tokenizer.get();
|
||||||
}
|
}
|
||||||
|
|
@ -739,6 +753,13 @@ HWY_EXPORT(GetCompressedWeightsT);
|
||||||
HWY_EXPORT(Generate2B);
|
HWY_EXPORT(Generate2B);
|
||||||
HWY_EXPORT(Generate7B);
|
HWY_EXPORT(Generate7B);
|
||||||
|
|
||||||
|
// Allocates the key and value buffers of a KVCache.
//
// size_cache_pos: number of floats stored per sequence position.
// seq_len:        number of sequence positions.
//
// Both buffers hold seq_len * size_cache_pos floats and are obtained from
// hwy::AllocateAligned.
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
  // Key and value caches have the same element count; compute it once.
  const size_t num_floats = seq_len * size_cache_pos;
  KVCache cache = {};
  cache.key_cache = hwy::AllocateAligned<float>(num_floats);
  cache.value_cache = hwy::AllocateAligned<float>(num_floats);
  return cache;
}
|
||||||
|
|
||||||
template <class Config>
|
template <class Config>
|
||||||
GemmaImpl<Config>::GemmaImpl(
|
GemmaImpl<Config>::GemmaImpl(
|
||||||
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
|
||||||
|
|
@ -755,17 +776,6 @@ GemmaImpl<Config>::GemmaImpl(
|
||||||
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
// HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Config>
|
|
||||||
KVCache GemmaImpl<Config>::CreateKVCache() const {
|
|
||||||
constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
|
|
||||||
Config::kQKVDim;
|
|
||||||
constexpr const size_t seq_len = Config::kSeqLen;
|
|
||||||
KVCache kv_cache = {};
|
|
||||||
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
return kv_cache;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GemmaImpl<ConfigGemma2B>::Generate(
|
void GemmaImpl<ConfigGemma2B>::Generate(
|
||||||
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
size_t max_tokens, size_t max_generated_tokens, float temperature,
|
||||||
|
|
@ -815,10 +825,6 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
|
||||||
}
|
}
|
||||||
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
Gemma::~Gemma() = default; // after GemmaInterface is defined
|
||||||
|
|
||||||
// Forwards to the type-erased implementation held in impl_, which creates
// the cache.
KVCache Gemma::CreateKVCache() const {
  KVCache cache = impl_->CreateKVCache();
  return cache;
}
|
|
||||||
|
|
||||||
// Returns the tokenizer pointer reported by the underlying implementation
// (impl_).
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
  const sentencepiece::SentencePieceProcessor* processor = impl_->Tokenizer();
  return processor;
}
|
||||||
|
|
|
||||||
4
gemma.h
4
gemma.h
|
|
@ -157,13 +157,15 @@ struct Gemma {
|
||||||
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
|
||||||
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
|
||||||
|
|
||||||
KVCache CreateKVCache() const;
|
|
||||||
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
|
||||||
|
|
||||||
std::unique_ptr<GemmaInterface> impl_;
|
std::unique_ptr<GemmaInterface> impl_;
|
||||||
gcpp::ModelTraining model_training;
|
gcpp::ModelTraining model_training;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||||
|
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len);
|
||||||
|
|
||||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||||
// probability is 0.0f.
|
// probability is 0.0f.
|
||||||
using StreamFunc = std::function<bool(int, float)>;
|
using StreamFunc = std::function<bool(int, float)>;
|
||||||
|
|
|
||||||
2
run.cc
2
run.cc
|
|
@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader, pool);
|
gcpp::Gemma model(loader, pool);
|
||||||
auto kv_cache = model.CreateKVCache();
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
if (const char* error = inference.Validate()) {
|
if (const char* error = inference.Validate()) {
|
||||||
ShowHelp(loader, inference, app);
|
ShowHelp(loader, inference, app);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue