Refactor kCachePosSize and kCacheLayerSize into separate functors.

PiperOrigin-RevId: 645048519
This commit is contained in:
The gemma.cpp Authors 2024-06-20 08:51:37 -07:00 committed by Copybara-Service
parent 48ebba8b7a
commit a85725614a
2 changed files with 18 additions and 9 deletions

View File

@ -73,6 +73,20 @@ constexpr size_t NumLayersOfTypeBefore(
return count; return count;
} }
// Functor whose call operator yields TConfig::kKVHeads * TConfig::kQKVDim * 2
// as a compile-time constant.
// NOTE(review): the factor 2 presumably accounts for storing both a key and a
// value vector per KV head - confirm against the KV cache layout.
// The defaulted second template parameter is unused in this primary template;
// presumably it is a hook for partial specializations - confirm.
template <class TConfig, typename = void>
struct CacheLayerSize {
  constexpr size_t operator()() const {
    // Evaluated in the operands' own type, then converted on return.
    const auto num_elements = TConfig::kKVHeads * TConfig::kQKVDim * 2;
    return num_elements;
  }
};
// Functor whose call operator yields TConfig::kGemmaLayers multiplied by the
// value of CacheLayerSize<TConfig>, as a compile-time constant.
// The defaulted second template parameter is unused in this primary template;
// presumably it is a hook for partial specializations - confirm.
template <class TConfig, typename = void>
struct CachePosSize {
  constexpr size_t operator()() const {
    // Per-layer size comes from the sibling functor so that any specialization
    // of CacheLayerSize is picked up here as well.
    const size_t per_layer = CacheLayerSize<TConfig>()();
    return TConfig::kGemmaLayers * per_layer;
  }
};
struct ConfigNoSSM { struct ConfigNoSSM {
static constexpr int kGriffinLayers = 0; static constexpr int kGriffinLayers = 0;

View File

@ -69,9 +69,6 @@ struct Activations {
static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
static constexpr size_t kCachePosSize =
TConfig::kGemmaLayers * kCacheLayerSize;
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
@ -117,12 +114,10 @@ struct CreateKVCache {
KVCache operator()() const { KVCache operator()() const {
KVCache kv_cache = {}; KVCache kv_cache = {};
const size_t size_cache_pos = const size_t size_cache_pos = CacheLayerSize<TConfig>()();
TConfig::kGemmaLayers * TConfig::kKVHeads * TConfig::kQKVDim;
if (size_cache_pos != 0) { if (size_cache_pos != 0) {
const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize; const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize;
kv_cache.kv_cache = kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
} }
if (TConfig::kGriffinLayers) { if (TConfig::kGriffinLayers) {
@ -372,8 +367,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
using TActivations = Activations<TConfig, kBatchSize>; using TActivations = Activations<TConfig, kBatchSize>;
constexpr size_t kQKVDim = TActivations::kQKVDim; constexpr size_t kQKVDim = TActivations::kQKVDim;
constexpr size_t kQStride = TActivations::kQStride; constexpr size_t kQStride = TActivations::kQStride;
constexpr size_t kCachePosSize = TActivations::kCachePosSize; constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize; constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
constexpr size_t kModelDim = TActivations::kModelDim; constexpr size_t kModelDim = TActivations::kModelDim;
constexpr size_t kHeads = TConfig::kHeads; constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads; constexpr size_t kKVHeads = TConfig::kKVHeads;