mirror of https://github.com/google/gemma.cpp.git
Expose underlying model configuration: number of layers, heads, etc.
PiperOrigin-RevId: 663747853
This commit is contained in:
parent
301dc8067a
commit
773333e5be
|
|
@ -120,4 +120,21 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
pools_.StopSpinning();
|
||||
}
|
||||
|
||||
template <typename TConfig>
|
||||
struct GetModelConfig {
|
||||
ModelConfigInfo operator()() const {
|
||||
return ModelConfigInfo{
|
||||
.layers = TConfig::kLayers,
|
||||
.model_dim = TConfig::kModelDim,
|
||||
.heads = TConfig::kHeads,
|
||||
.kv_heads = TConfig::kKVHeads,
|
||||
.qkv_dim = TConfig::kQKVDim,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
ModelConfigInfo Gemma::ModelConfig() const {
|
||||
return CallForModel<float, GetModelConfig>(info_.model);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -156,6 +156,15 @@ struct TimingInfo {
|
|||
size_t tokens_generated = 0;
|
||||
};
|
||||
|
||||
// ModelConfigInfo holds model configuration details: number of layers, etc.
|
||||
struct ModelConfigInfo {
|
||||
const int layers;
|
||||
const int model_dim;
|
||||
const int heads;
|
||||
const int kv_heads;
|
||||
const int qkv_dim;
|
||||
};
|
||||
|
||||
class Gemma {
|
||||
public:
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||
|
|
@ -166,6 +175,7 @@ class Gemma {
|
|||
PerClusterPools& pools);
|
||||
~Gemma();
|
||||
|
||||
ModelConfigInfo ModelConfig() const;
|
||||
const ModelInfo& Info() const { return info_; }
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||
|
|
|
|||
Loading…
Reference in New Issue