Expose underlying model configuration: number of layers, heads, etc.

PiperOrigin-RevId: 663747853
This commit is contained in:
Paul Chang 2024-08-16 09:02:40 -07:00 committed by Copybara-Service
parent 301dc8067a
commit 773333e5be
2 changed files with 27 additions and 0 deletions

View File

@ -120,4 +120,21 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
pools_.StopSpinning();
}
// Functor that packages the compile-time dimensions declared by `TConfig`
// into a runtime `ModelConfigInfo` value. It is passed as the functor
// template to `CallForModel`, so it must stay default-constructible and
// callable with no arguments.
//
// `TConfig` must expose the static constants kLayers, kModelDim, kHeads,
// kKVHeads and kQKVDim.
template <typename TConfig>
struct GetModelConfig {
  // Returns a `ModelConfigInfo` whose fields mirror the constants of `TConfig`.
  ModelConfigInfo operator()() const {
    ModelConfigInfo config{
        .layers = TConfig::kLayers,
        .model_dim = TConfig::kModelDim,
        .heads = TConfig::kHeads,
        .kv_heads = TConfig::kKVHeads,
        .qkv_dim = TConfig::kQKVDim,
    };
    return config;
  }
};
// Returns the configuration (layer count, head counts, dimensions) of the
// model selected by `info_.model`, by dispatching `GetModelConfig` through
// `CallForModel`.
ModelConfigInfo Gemma::ModelConfig() const {
  const ModelConfigInfo model_config =
      CallForModel<float, GetModelConfig>(info_.model);
  return model_config;
}
} // namespace gcpp

View File

@ -156,6 +156,15 @@ struct TimingInfo {
size_t tokens_generated = 0;
};
// ModelConfigInfo holds model configuration details: number of layers, etc.
//
// This is a plain value type. The members are deliberately non-const so that
// instances can be copy/move-assigned and stored in containers; use a
// `const ModelConfigInfo` object where immutability is wanted. All fields
// default to 0, so a default-constructed instance is well defined.
struct ModelConfigInfo {
  int layers = 0;     // Number of layers (filled from the config's kLayers).
  int model_dim = 0;  // Model dimension (filled from the config's kModelDim).
  int heads = 0;      // Number of heads (filled from the config's kHeads).
  int kv_heads = 0;   // Number of KV heads (filled from the config's kKVHeads).
  int qkv_dim = 0;    // QKV dimension (filled from the config's kQKVDim).
};
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
@ -166,6 +175,7 @@ class Gemma {
PerClusterPools& pools);
~Gemma();
ModelConfigInfo ModelConfig() const;
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }