mirror of https://github.com/google/gemma.cpp.git
Refactor kCachePosSize and kCacheLayerSize into separate functors.
PiperOrigin-RevId: 645048519
This commit is contained in:
parent
48ebba8b7a
commit
a85725614a
|
|
@ -73,6 +73,20 @@ constexpr size_t NumLayersOfTypeBefore(
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Functor returning the number of KV-cache elements that a single layer
// contributes for one position: kKVHeads * kQKVDim * 2. The factor 2
// presumably covers the K and V vectors (each of length kQKVDim) - confirm
// against the attention code.
// NOTE(review): the defaulted `typename = void` parameter looks like a hook
// for SFINAE-based specializations per config; none are visible here.
template <class TConfig, typename = void>
struct CacheLayerSize {
  constexpr size_t operator()() const {
    // Widen each operand before multiplying so the product is formed in
    // size_t rather than in the (possibly int) type of the config constants.
    return static_cast<size_t>(TConfig::kKVHeads) *
           static_cast<size_t>(TConfig::kQKVDim) * 2;
  }
};
|
||||||
|
|
||||||
|
template <class TConfig, typename = void>
|
||||||
|
struct CachePosSize {
|
||||||
|
constexpr size_t operator()() const {
|
||||||
|
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ConfigNoSSM {
|
struct ConfigNoSSM {
|
||||||
static constexpr int kGriffinLayers = 0;
|
static constexpr int kGriffinLayers = 0;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -69,9 +69,6 @@ struct Activations {
|
||||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
|
|
||||||
static constexpr size_t kCachePosSize =
|
|
||||||
TConfig::kGemmaLayers * kCacheLayerSize;
|
|
||||||
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
|
static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention
|
||||||
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
|
||||||
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
|
||||||
|
|
@ -117,12 +114,10 @@ struct CreateKVCache {
|
||||||
KVCache operator()() const {
|
KVCache operator()() const {
|
||||||
KVCache kv_cache = {};
|
KVCache kv_cache = {};
|
||||||
|
|
||||||
const size_t size_cache_pos =
|
// The per-position footprint must span every layer, so use CachePosSize
// (kGemmaLayers * CacheLayerSize). CacheLayerSize alone drops the
// kGemmaLayers factor that the pre-refactor expression had and would
// under-allocate the cache.
const size_t size_cache_pos = CachePosSize<TConfig>()();
|
||||||
TConfig::kGemmaLayers * TConfig::kKVHeads * TConfig::kQKVDim;
|
|
||||||
if (size_cache_pos != 0) {
|
if (size_cache_pos != 0) {
|
||||||
const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize;
|
const size_t seq_len = TConfig::kSeqLen + kPrefillBatchSize;
|
||||||
kv_cache.kv_cache =
|
kv_cache.kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||||
hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (TConfig::kGriffinLayers) {
|
if (TConfig::kGriffinLayers) {
|
||||||
|
|
@ -372,8 +367,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
|
||||||
using TActivations = Activations<TConfig, kBatchSize>;
|
using TActivations = Activations<TConfig, kBatchSize>;
|
||||||
constexpr size_t kQKVDim = TActivations::kQKVDim;
|
constexpr size_t kQKVDim = TActivations::kQKVDim;
|
||||||
constexpr size_t kQStride = TActivations::kQStride;
|
constexpr size_t kQStride = TActivations::kQStride;
|
||||||
constexpr size_t kCachePosSize = TActivations::kCachePosSize;
|
constexpr size_t kCachePosSize = CachePosSize<TConfig>()();
|
||||||
constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize;
|
constexpr size_t kCacheLayerSize = CacheLayerSize<TConfig>()();
|
||||||
constexpr size_t kModelDim = TActivations::kModelDim;
|
constexpr size_t kModelDim = TActivations::kModelDim;
|
||||||
constexpr size_t kHeads = TConfig::kHeads;
|
constexpr size_t kHeads = TConfig::kHeads;
|
||||||
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue