diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index b7bde6e..3af2a43 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -120,4 +120,21 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   pools_.StopSpinning();
 }
 
+template <typename TConfig>
+struct GetModelConfig {
+  ModelConfigInfo operator()() const {
+    return ModelConfigInfo{
+        .layers = TConfig::kLayers,
+        .model_dim = TConfig::kModelDim,
+        .heads = TConfig::kHeads,
+        .kv_heads = TConfig::kKVHeads,
+        .qkv_dim = TConfig::kQKVDim,
+    };
+  }
+};
+
+ModelConfigInfo Gemma::ModelConfig() const {
+  return CallForModel<GetModelConfig>(info_.model);
+}
+
 }  // namespace gcpp
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 70f3280..ae4d5eb 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -156,6 +156,15 @@ struct TimingInfo {
   size_t tokens_generated = 0;
 };
 
+// ModelConfigInfo holds model configuration details: number of layers, etc.
+struct ModelConfigInfo {
+  const int layers;
+  const int model_dim;
+  const int heads;
+  const int kv_heads;
+  const int qkv_dim;
+};
+
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
@@ -166,6 +175,7 @@ class Gemma {
         PerClusterPools& pools);
   ~Gemma();
 
+  ModelConfigInfo ModelConfig() const;
   const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
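
For reviewers, a minimal sketch of how a caller might consume the new accessor. Only Gemma::ModelConfig() and the ModelConfigInfo fields come from this patch; KVElementsPerToken and PrintModelShape are hypothetical helpers written for illustration, not part of gemma.cpp.

// Hypothetical usage sketch (not part of the patch): read per-model
// dimensions from Gemma::ModelConfig() instead of hard-coding them.
#include <cstddef>
#include <cstdio>

#include "gemma/gemma.h"

namespace gcpp {

// Elements per token of KV cache across all layers (keys + values), derived
// from the dimensions exposed by ModelConfigInfo. Illustrative only.
size_t KVElementsPerToken(const Gemma& gemma) {
  const ModelConfigInfo cfg = gemma.ModelConfig();
  return static_cast<size_t>(2) * cfg.layers * cfg.kv_heads * cfg.qkv_dim;
}

// Prints the model shape; handy when checking which config was dispatched.
void PrintModelShape(const Gemma& gemma) {
  const ModelConfigInfo cfg = gemma.ModelConfig();
  std::printf("layers=%d model_dim=%d heads=%d kv_heads=%d qkv_dim=%d\n",
              cfg.layers, cfg.model_dim, cfg.heads, cfg.kv_heads, cfg.qkv_dim);
}

}  // namespace gcpp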