diff --git a/gemma/configs.h b/gemma/configs.h index 0ca14fd..c1fc176 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -73,6 +73,20 @@ constexpr size_t NumLayersOfTypeBefore( return count; } +template +struct CacheLayerSize { + constexpr size_t operator()() const { + return TConfig::kKVHeads * TConfig::kQKVDim * 2; + } +}; + +template +struct CachePosSize { + constexpr size_t operator()() const { + return TConfig::kGemmaLayers * CacheLayerSize()(); + } +}; + struct ConfigNoSSM { static constexpr int kGriffinLayers = 0; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 4f4e8ca..f32cde5 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -69,9 +69,6 @@ struct Activations { static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2; - static constexpr size_t kCachePosSize = - TConfig::kGemmaLayers * kCacheLayerSize; static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. @@ -117,12 +114,10 @@ struct CreateKVCache { KVCache operator()() const { KVCache kv_cache = {}; - const size_t size_cache_pos = - TConfig::kGemmaLayers * TConfig::kKVHeads * TConfig::kQKVDim; + const size_t size_cache_pos = CacheLayerSize()(); if (size_cache_pos != 0) { const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize; - kv_cache.kv_cache = - hwy::AllocateAligned(seq_len * size_cache_pos * 2); + kv_cache.kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); } if (TConfig::kGriffinLayers) { @@ -372,8 +367,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, using TActivations = Activations; constexpr size_t kQKVDim = TActivations::kQKVDim; constexpr size_t kQStride = TActivations::kQStride; - constexpr size_t kCachePosSize = TActivations::kCachePosSize; - constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize; + constexpr size_t kCachePosSize = CachePosSize()(); + constexpr size_t kCacheLayerSize = CacheLayerSize()(); constexpr size_t kModelDim = TActivations::kModelDim; constexpr size_t kHeads = TConfig::kHeads; constexpr size_t kKVHeads = TConfig::kKVHeads;