diff --git a/gemma/attention.cc b/gemma/attention.cc index 5aab57e..570c4f4 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -20,7 +20,6 @@ #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "util/zones.h" -#include "hwy/base.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -59,8 +58,8 @@ size_t FloatsPerVector() { // The k-cache and v-cache are setup without knowing NF. So if it hasn't been // done already, reshape it to take NF into account. -void MaybeReshapeCache(const size_t default_cols, MatPtrT& cache) { - if (default_cols == cache.Cols()) { +void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache) { + if (kv.Cols() > cache.Cols()) { cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); } } @@ -72,50 +71,13 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, // is a tiny fraction of the overall computation, and it is linear in the // token length. const size_t kFloatsPerTile = 2 * FloatsPerVector(); - const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); for (size_t i = 0; i < qkv_dim; i += 2) { k[i * kFloatsPerTile] = kv[i]; k[i * kFloatsPerTile + 1] = kv[i + 1]; } - for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) { - k[i * kFloatsPerTile] = hwy::ConvertScalarTo(0.0f); - k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo(0.0f); - } for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) { - if (i + kFloatsPerTile <= qkv_dim) { - for (size_t j = 0; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; - } - } else { - for (size_t j = 0; j < qkv_dim - i; j++) { - v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; - } - for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); - } - } - } - for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim; - i += kFloatsPerTile) { for (size_t j = 0; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); - } - } -} - -// Zeros out a part of k and v that corresponds to out-of-bounds cache -// positions. -void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v, - size_t qkv_dim) { - const size_t kFloatsPerTile = 2 * FloatsPerVector(); - const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); - for (size_t i = 0; i < kRoundedQkvDim; i += 2) { - k[i * kFloatsPerTile] = hwy::ConvertScalarTo(0.0f); - k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo(0.0f); - } - for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) { - for (size_t j = 0; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); + v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; } } } @@ -352,22 +314,16 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(), - qbatch.KV(qi).k_cache); - MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(), - qbatch.KV(qi).v_cache); + MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache); + MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache); } const size_t kFloatsPerVector = FloatsPerVector(); - const size_t kRoundedTokens = - hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector); - const size_t kRoundedNumInterleaved = - kRoundedTokens * div_qbatch.GetDivisor(); // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. ParallelFor( - Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx, + Parallelism::kFlat, kv_heads * num_interleaved, env.ctx, /*cluster_idx=*/0, Callers::kAttComputeQKV, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; @@ -375,28 +331,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t token_idx = div_qbatch.Divide(interleaved_idx); const size_t cache_pos = qbatch.Pos(qi) + token_idx; - if (token_idx >= kRoundedTokens) { - return; - } - // The innermost dimension of v is 2NF values from qkv_dim because they - // will be loaded into a BF16 vector to be scaled and added to the - // cached attention output in 2 NF-sized registers. - auto& k_cache = qbatch.KV(qi).k_cache; - KV_t* HWY_RESTRICT k = - k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector, - cache_pos); - auto& v_cache = qbatch.KV(qi).v_cache; - KV_t* HWY_RESTRICT v = - v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector, - cache_pos); - if (token_idx >= num_tokens) { - // Create a zero-filled K/V pair for padding for out-of-sequence - // tokens. - TransposeOOBKVCacheRow(k, v, qkv_dim); - return; - } // --seq_len must be large enough to avoid wraparound. HWY_DASSERT(cache_pos < activations.SeqLen()); auto& kv_cache = qbatch.KV(qi).kv_cache; @@ -407,6 +341,22 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // The innermost dimension of k is 2 values from qkv_dim because they // are going to be used in a BF16 dot product involving pairs of // values over NF k positions. + // The innermost dimension of v is 2NF values from qkv_dim because they + // will be loaded into a BF16 vector to be scaled and added to the + // cached attention output in 2 NF-sized registers. + // TODO(rays): factor out these calculations into functions. + auto& k_cache = qbatch.KV(qi).k_cache; + KV_t* HWY_RESTRICT k = + k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + (layer_idx * cache_layer_size + head * qkv_dim * 2) * + kFloatsPerVector + + (cache_pos % (2 * kFloatsPerVector)) * 2; + auto& v_cache = qbatch.KV(qi).v_cache; + KV_t* HWY_RESTRICT v = + v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + (layer_idx * cache_layer_size + head * qkv_dim * 2) * + kFloatsPerVector + + (cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector; HWY_ALIGN float kv_f32[2 * kMaxQKVDim]; const hn::ScalableTag df; diff --git a/gemma/attention.h b/gemma/attention.h index bb8a743..14870de 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -33,7 +33,7 @@ namespace gcpp { namespace NAMESPACE { \ size_t FloatsPerVector(); \ \ - void MaybeReshapeCache(size_t default_cols, MatPtrT& cache); \ + void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache); \ \ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \ KV_t* HWY_RESTRICT v, size_t qkv_dim); \ diff --git a/gemma/configs.h b/gemma/configs.h index 0c5dbe8..3811e98 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -29,12 +29,9 @@ #include "io/fields.h" // IFieldsVisitor #include "io/io.h" // Path #include "util/basics.h" -#include "hwy/detect_compiler_arch.h" namespace gcpp { -constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16); - HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 2adab18..7ed1d69 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -1700,6 +1700,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, // A "head group" in the context of GQA refers to a collection of query // heads that share the same key and value heads. const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; + const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); const size_t total_tasks = token_batch * layer_config.heads; size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks, @@ -1715,9 +1716,11 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, params.clear(); for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) { for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) { + const size_t head_offset = kv_head * qkv_dim * 2; + const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset; params.push_back(Tile148Params{ .qi_index = qi, - .kv_head = kv_head, + .kv_offset = kv_offset, }); for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const size_t pos = qbatch.Pos(qi) + batch_idx; @@ -1743,7 +1746,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, // current tile is full so start new tile. params.push_back(Tile148Params{ .qi_index = qi, - .kv_head = kv_head, + .kv_offset = kv_offset, }); } const size_t head = head_group + kHeadGroups * kv_head; @@ -2154,20 +2157,13 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention); auto& param = params[task]; auto& kT_cache = qbatch.KV(param.qi_index).k_cache; - const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); MatPtrT kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), - kRoundedQkvDim * 2 * kNF)); - kT.SetPtr( - kT_cache.Row(0) + qbatch.KV(param.qi_index) - .cache->KOrVOffset(layer_idx, param.kv_head, kNF), - kT_cache.Stride()); + qkv_dim * 2 * kNF)); + kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride()); auto& vT_cache = qbatch.KV(param.qi_index).v_cache; MatPtrT vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), - kRoundedQkvDim * 2 * kNF)); - vT.SetPtr( - vT_cache.Row(0) + qbatch.KV(param.qi_index) - .cache->KOrVOffset(layer_idx, param.kv_head, kNF), - vT_cache.Stride()); + qkv_dim * 2 * kNF)); + vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride()); MatPtrT& att_out = param.i_of_n == 0 ? activations.att_out : activations.att_out_reps; DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx, diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 446ad51..fd693d9 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -144,15 +144,10 @@ void TestFlashAttention(size_t target_parallelism, const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t seq_len = static_cast(attention.div_seq_len.GetDivisor()); - MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), - qbatch.KV(0).k_cache); - MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), - qbatch.KV(0).v_cache); + MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache); + MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache); auto& kvc = qbatch.KV(0).kv_cache; - using DF = hn::ScalableTag; - const DF df; - const size_t kNF = hn::Lanes(df); - const size_t kFloatsPerTile = 2 * kNF; + const size_t kFloatsPerTile = 2 * FloatsPerVector(); for (size_t h = 0; h < layer_config.heads; ++h) { // Make strided views into the kv cache for // this query and head. @@ -165,12 +160,12 @@ void TestFlashAttention(size_t target_parallelism, SetMat(h + layer_config.heads * 2, v); for (size_t p = 0; p < tokens.size(); ++p) { KV_t* HWY_RESTRICT k_src = k.Row(p); - KV_t* HWY_RESTRICT k_dest = - qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + - qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p); - KV_t* HWY_RESTRICT v_dest = - qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + - qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p); + KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + + head_offset * kFloatsPerTile / 2 + + p % kFloatsPerTile * 2; + KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + + head_offset * kFloatsPerTile / 2 + + p % kFloatsPerTile * kFloatsPerTile; TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim); } @@ -181,6 +176,9 @@ void TestFlashAttention(size_t target_parallelism, // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); const size_t total_tasks = tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h index 5473692..8c446e0 100644 --- a/gemma/flash_structs.h +++ b/gemma/flash_structs.h @@ -48,8 +48,8 @@ struct Tile148Params { uint32_t max_last_pos = 0; // Index into the qbatch.KV is the same for each row in the tile. uint32_t qi_index; - // kv_head is the same for each row in the tile. - uint32_t kv_head; + // Index into the kv_cache is the same for each row in the tile. + uint32_t kv_offset; // In the original task, the index to the split tasks of the first split task. uint32_t split_index = 0; // The index of the split for running split attention. diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index d94f917..f33cd21 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -29,6 +29,11 @@ namespace gcpp { +// TODO: rays - Remove this once hwy is updated. +#ifndef HWY_ARCH_MAX_BYTES +#define HWY_ARCH_MAX_BYTES 256 +#endif + // Number of rows for KV cache. Note that both rows and cols are u32, and // the total number of elements can exceed 2^32. static size_t CappedSeqLen(const ModelConfig& config, @@ -41,13 +46,8 @@ static size_t CappedSeqLen(const ModelConfig& config, return inference_args.seq_len; } -KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers, - size_t kv_heads, size_t qkv_dim, const Allocator& allocator) - : num_layers(num_layers), - kv_heads(kv_heads), - qkv_dim(qkv_dim), - rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)), - kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), +KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) + : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), // WARNING: the rows and cols of k_cache and v_cache will be modified // before use! // The rows will be reduced by a factor of 2xkFloatsPerVector, and the @@ -56,12 +56,14 @@ KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers, // machine architecture, since kFloatsPerVector is architecture dependent. // The change is shape is safe only if the padding is kPacked. k_cache("k", - Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector), - KOrVDefaultCols()), + Extents2D(HWY_MAX(kv_extents.rows, + 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), + kv_extents.cols / 2), allocator, MatPadding::kPacked), v_cache("v", - Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector), - KOrVDefaultCols()), + Extents2D(HWY_MAX(kv_extents.rows, + 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), + kv_extents.cols / 2), allocator, MatPadding::kPacked), allocator_(allocator) {} @@ -69,8 +71,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator) : KVCache( Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), - config.layer_configs.size(), config.layer_configs[0].kv_heads, - config.layer_configs[0].qkv_dim, allocator) {} + allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const RuntimeConfig& runtime_config, @@ -134,7 +135,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, } KVCache KVCache::Copy() { - KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_); + KVCache copy(kv_cache.Extents(), allocator_); CopyMat(kv_cache, copy.kv_cache); return copy; diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index dab636b..5fe1f1e 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -91,38 +91,6 @@ struct KVCache { return {start_ptr, source_ptr}; } - // Returns the default size of a row in k_cache or v_cache, before scaling by - // 2 * kNF. - size_t KOrVDefaultCols() const { - return num_layers * kv_heads * rounded_qkv_dim; - } - - // Returns an offset into a row of k_cache or v_cache at a position that is - // aligned to the tile size (a multiple of 2kNF). - size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx, - const size_t kNF) const { - return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF; - } - - // Returns an offset into k_cache at any given position. - size_t KOffset(const size_t layer_idx, const size_t kv_head_idx, - const size_t kNF, const size_t pos) const { - return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2; - } - - // Returns an offset into v_cache at any given position. - size_t VOffset(const size_t layer_idx, const size_t kv_head_idx, - const size_t kNF, const size_t pos) const { - return KOrVOffset(layer_idx, kv_head_idx, kNF) + - (pos % (2 * kNF)) * 2 * kNF; - } - - // Saved sizes for computing offsets into the KV cache. - size_t num_layers = 0; - size_t kv_heads = 0; - size_t qkv_dim = 0; - size_t rounded_qkv_dim = 0; - static constexpr size_t kTileSize = 32; std::optional tiled_seq_len = std::nullopt; // Default Format @@ -191,8 +159,7 @@ struct KVCache { const Allocator& allocator_; // For use by other ctor and Copy() - KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads, - size_t qkv_dim, const Allocator& allocator); + KVCache(const Extents2D& kv_extents, const Allocator& allocator); }; inline size_t KVCachePtr::SeqLen() const {