diff --git a/gemma/configs.h b/gemma/configs.h index 9b82880..9388c9d 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -61,12 +61,29 @@ constexpr std::array FixedLayerConfig( return config; } +template +constexpr size_t NumLayersOfTypeBefore( + const std::array& layers, + LayerAttentionType type, size_t num) { + size_t count = 0; + for (size_t i = 0; i < num; i++) { + if (layers[i] == type) count++; + } + return count; +} + struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<28>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = + NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); + static constexpr int kGriffinLayers = + NumLayersOfTypeBefore(kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + kLayers); static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kHeads = 16; @@ -91,6 +108,12 @@ struct ConfigGemma2B { static constexpr std::array kLayerConfig = FixedLayerConfig<18>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = + NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); + static constexpr int kGriffinLayers = + NumLayersOfTypeBefore(kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + kLayers); static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; @@ -143,6 +166,12 @@ struct ConfigGriffin2B { LayerAttentionType::kGriffinRecurrentBlock, }; static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = + NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); + static constexpr int kGriffinLayers = + NumLayersOfTypeBefore(kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + kLayers); static constexpr int kModelDim = 2560; static constexpr int kFFHiddenDim = 7680; static constexpr int kHeads = 10; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 9b8fd93..b64cdea 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -71,30 +71,6 @@ constexpr bool kShowTokenization = false; namespace gcpp { -template -constexpr size_t NumLayersOfTypeBefore( - const std::array& layers, - LayerAttentionType type, size_t num) { - size_t count = 0; - for (size_t i = 0; i < num; i++) { - if (layers[i] == type) count++; - } - return count; -} - -template -constexpr size_t NumGemmaLayers() { - return NumLayersOfTypeBefore(TConfig::kLayerConfig, - LayerAttentionType::kGemma, TConfig::kLayers); -} - -template -constexpr size_t NumGriffinLayers() { - return NumLayersOfTypeBefore(TConfig::kLayerConfig, - LayerAttentionType::kGriffinRecurrentBlock, - TConfig::kLayers); -} - template struct Layer { Layer() = default; @@ -389,7 +365,7 @@ struct Activations { static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kCachePosSize = - NumGemmaLayers() * kKVHeads * kQKVDim; + TConfig::kGemmaLayers * kKVHeads * kQKVDim; static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim; std::array x; // input @@ -443,11 +419,11 @@ template KVCache CreateKVCache() { constexpr size_t kConv1dWidth = Config::kConv1dWidth; return CreateKVCache( - NumGemmaLayers() * Config::kKVHeads * Config::kQKVDim, + Config::kGemmaLayers * Config::kKVHeads * Config::kQKVDim, Config::kSeqLen, - NumGriffinLayers() * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * + Config::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * Config::kModelDim, - NumGriffinLayers() * Config::kModelDim); + Config::kGriffinLayers * Config::kModelDim); } KVCache CreateKVCache(Model type) { diff --git a/gemma/gemma.h b/gemma/gemma.h index abe402c..d2674e1 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -37,13 +37,13 @@ constexpr bool kSystemPrompt = false; struct KVCache { hwy::AlignedFreeUniquePtr - key_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim + key_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr - value_cache; // kSeqLen * kNumGemmaLayers * kKVHeads * kQKVDim + value_cache; // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim hwy::AlignedFreeUniquePtr - conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kNumGriffinLayers + conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers hwy::AlignedFreeUniquePtr - rglru_cache; // kModelDim * kNumGriffinLayers + rglru_cache; // kModelDim * kGriffinLayers }; // Model variants: see configs.h for details.