diff --git a/gemma/attention.cc b/gemma/attention.cc index 570c4f4..5aab57e 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -20,6 +20,7 @@ #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "util/zones.h" +#include "hwy/base.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -58,8 +59,8 @@ size_t FloatsPerVector() { // The k-cache and v-cache are setup without knowing NF. So if it hasn't been // done already, reshape it to take NF into account. -void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache) { - if (kv.Cols() > cache.Cols()) { +void MaybeReshapeCache(const size_t default_cols, MatPtrT& cache) { + if (default_cols == cache.Cols()) { cache.ReshapePackedRowsToCols(2 * FloatsPerVector()); } } @@ -71,13 +72,50 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, // is a tiny fraction of the overall computation, and it is linear in the // token length. const size_t kFloatsPerTile = 2 * FloatsPerVector(); + const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); for (size_t i = 0; i < qkv_dim; i += 2) { k[i * kFloatsPerTile] = kv[i]; k[i * kFloatsPerTile + 1] = kv[i + 1]; } + for (size_t i = qkv_dim; i < kRoundedQkvDim; i += 2) { + k[i * kFloatsPerTile] = hwy::ConvertScalarTo(0.0f); + k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo(0.0f); + } for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) { + if (i + kFloatsPerTile <= qkv_dim) { + for (size_t j = 0; j < kFloatsPerTile; j++) { + v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; + } + } else { + for (size_t j = 0; j < qkv_dim - i; j++) { + v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; + } + for (size_t j = qkv_dim - i; j < kFloatsPerTile; j++) { + v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); + } + } + } + for (size_t i = hwy::RoundUpTo(qkv_dim, kFloatsPerTile); i < kRoundedQkvDim; + i += kFloatsPerTile) { for (size_t j = 0; j < kFloatsPerTile; j++) { - v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim]; + v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); + } + } +} + +// Zeros out a part of k and v that corresponds to out-of-bounds cache +// positions. +void TransposeOOBKVCacheRow(KV_t* HWY_RESTRICT k, KV_t* HWY_RESTRICT v, + size_t qkv_dim) { + const size_t kFloatsPerTile = 2 * FloatsPerVector(); + const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); + for (size_t i = 0; i < kRoundedQkvDim; i += 2) { + k[i * kFloatsPerTile] = hwy::ConvertScalarTo(0.0f); + k[i * kFloatsPerTile + 1] = hwy::ConvertScalarTo(0.0f); + } + for (size_t i = 0; i < kRoundedQkvDim; i += kFloatsPerTile) { + for (size_t j = 0; j < kFloatsPerTile; j++) { + v[i * kFloatsPerTile + j] = hwy::ConvertScalarTo(0.0f); } } } @@ -314,16 +352,22 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, /*add=*/nullptr, env, kv_rows); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache); - MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache); + MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(), + qbatch.KV(qi).k_cache); + MaybeReshapeCache(qbatch.KV(qi).cache->KOrVDefaultCols(), + qbatch.KV(qi).v_cache); } const size_t kFloatsPerVector = FloatsPerVector(); + const size_t kRoundedTokens = + hwy::RoundUpTo(num_tokens, 2 * kFloatsPerVector); + const size_t kRoundedNumInterleaved = + kRoundedTokens * div_qbatch.GetDivisor(); // Apply positional encodings for K. // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. ParallelFor( - Parallelism::kFlat, kv_heads * num_interleaved, env.ctx, + Parallelism::kFlat, kv_heads * kRoundedNumInterleaved, env.ctx, /*cluster_idx=*/0, Callers::kAttComputeQKV, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; @@ -331,6 +375,28 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t token_idx = div_qbatch.Divide(interleaved_idx); const size_t cache_pos = qbatch.Pos(qi) + token_idx; + if (token_idx >= kRoundedTokens) { + return; + } + // The innermost dimension of v is 2NF values from qkv_dim because they + // will be loaded into a BF16 vector to be scaled and added to the + // cached attention output in 2 NF-sized registers. + auto& k_cache = qbatch.KV(qi).k_cache; + KV_t* HWY_RESTRICT k = + k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + qbatch.KV(qi).cache->KOffset(layer_idx, head, kFloatsPerVector, + cache_pos); + auto& v_cache = qbatch.KV(qi).v_cache; + KV_t* HWY_RESTRICT v = + v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + + qbatch.KV(qi).cache->VOffset(layer_idx, head, kFloatsPerVector, + cache_pos); + if (token_idx >= num_tokens) { + // Create a zero-filled K/V pair for padding for out-of-sequence + // tokens. + TransposeOOBKVCacheRow(k, v, qkv_dim); + return; + } // --seq_len must be large enough to avoid wraparound. HWY_DASSERT(cache_pos < activations.SeqLen()); auto& kv_cache = qbatch.KV(qi).kv_cache; @@ -341,22 +407,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // The innermost dimension of k is 2 values from qkv_dim because they // are going to be used in a BF16 dot product involving pairs of // values over NF k positions. - // The innermost dimension of v is 2NF values from qkv_dim because they - // will be loaded into a BF16 vector to be scaled and added to the - // cached attention output in 2 NF-sized registers. - // TODO(rays): factor out these calculations into functions. - auto& k_cache = qbatch.KV(qi).k_cache; - KV_t* HWY_RESTRICT k = - k_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - (layer_idx * cache_layer_size + head * qkv_dim * 2) * - kFloatsPerVector + - (cache_pos % (2 * kFloatsPerVector)) * 2; - auto& v_cache = qbatch.KV(qi).v_cache; - KV_t* HWY_RESTRICT v = - v_cache.Row(cache_pos / (2 * kFloatsPerVector)) + - (layer_idx * cache_layer_size + head * qkv_dim * 2) * - kFloatsPerVector + - (cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector; HWY_ALIGN float kv_f32[2 * kMaxQKVDim]; const hn::ScalableTag df; diff --git a/gemma/attention.h b/gemma/attention.h index 14870de..bb8a743 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -33,7 +33,7 @@ namespace gcpp { namespace NAMESPACE { \ size_t FloatsPerVector(); \ \ - void MaybeReshapeCache(const MatPtrT& kv, MatPtrT& cache); \ + void MaybeReshapeCache(size_t default_cols, MatPtrT& cache); \ \ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \ KV_t* HWY_RESTRICT v, size_t qkv_dim); \ diff --git a/gemma/configs.h b/gemma/configs.h index 3811e98..0c5dbe8 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -29,9 +29,12 @@ #include "io/fields.h" // IFieldsVisitor #include "io/io.h" // Path #include "util/basics.h" +#include "hwy/detect_compiler_arch.h" namespace gcpp { +constexpr size_t kMaxBF16PerVector = HWY_ARCH_MAX_BYTES / sizeof(BF16); + HWY_INLINE_VAR constexpr int kAttentionUseOld = 2; HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 7ed1d69..2adab18 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -1700,7 +1700,6 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, // A "head group" in the context of GQA refers to a collection of query // heads that share the same key and value heads. const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - const size_t cache_layer_size = layer_config.CacheLayerSize(); const size_t token_batch = num_tokens * div_qbatch.GetDivisor(); const size_t total_tasks = token_batch * layer_config.heads; size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, num_tokens, total_tasks, @@ -1716,11 +1715,9 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, params.clear(); for (uint32_t qi = 0; qi < div_qbatch.GetDivisor(); ++qi) { for (uint32_t kv_head = 0; kv_head < layer_config.kv_heads; ++kv_head) { - const size_t head_offset = kv_head * qkv_dim * 2; - const uint32_t kv_offset = layer_idx * cache_layer_size + head_offset; params.push_back(Tile148Params{ .qi_index = qi, - .kv_offset = kv_offset, + .kv_head = kv_head, }); for (uint32_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const size_t pos = qbatch.Pos(qi) + batch_idx; @@ -1746,7 +1743,7 @@ void ComputeFlashParams(size_t num_tokens, const size_t target_parallelism, // current tile is full so start new tile. params.push_back(Tile148Params{ .qi_index = qi, - .kv_offset = kv_offset, + .kv_head = kv_head, }); } const size_t head = head_group + kHeadGroups * kv_head; @@ -2157,13 +2154,20 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention); auto& param = params[task]; auto& kT_cache = qbatch.KV(param.qi_index).k_cache; + const size_t kRoundedQkvDim = hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector); MatPtrT kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), - qkv_dim * 2 * kNF)); - kT.SetPtr(kT_cache.Row(0) + param.kv_offset * kNF, kT_cache.Stride()); + kRoundedQkvDim * 2 * kNF)); + kT.SetPtr( + kT_cache.Row(0) + qbatch.KV(param.qi_index) + .cache->KOrVOffset(layer_idx, param.kv_head, kNF), + kT_cache.Stride()); auto& vT_cache = qbatch.KV(param.qi_index).v_cache; MatPtrT vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF), - qkv_dim * 2 * kNF)); - vT.SetPtr(vT_cache.Row(0) + param.kv_offset * kNF, vT_cache.Stride()); + kRoundedQkvDim * 2 * kNF)); + vT.SetPtr( + vT_cache.Row(0) + qbatch.KV(param.qi_index) + .cache->KOrVOffset(layer_idx, param.kv_head, kNF), + vT_cache.Stride()); MatPtrT& att_out = param.i_of_n == 0 ? activations.att_out : activations.att_out_reps; DispatchTileFlashAttention148(param, activations.q_bf, kT, vT, layer_idx, diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index fd693d9..446ad51 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -144,10 +144,15 @@ void TestFlashAttention(size_t target_parallelism, const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; const size_t seq_len = static_cast(attention.div_seq_len.GetDivisor()); - MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).k_cache); - MaybeReshapeCache(qbatch.KV(0).kv_cache, qbatch.KV(0).v_cache); + MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), + qbatch.KV(0).k_cache); + MaybeReshapeCache(qbatch.KV(0).cache->KOrVDefaultCols(), + qbatch.KV(0).v_cache); auto& kvc = qbatch.KV(0).kv_cache; - const size_t kFloatsPerTile = 2 * FloatsPerVector(); + using DF = hn::ScalableTag; + const DF df; + const size_t kNF = hn::Lanes(df); + const size_t kFloatsPerTile = 2 * kNF; for (size_t h = 0; h < layer_config.heads; ++h) { // Make strided views into the kv cache for // this query and head. @@ -160,12 +165,12 @@ void TestFlashAttention(size_t target_parallelism, SetMat(h + layer_config.heads * 2, v); for (size_t p = 0; p < tokens.size(); ++p) { KV_t* HWY_RESTRICT k_src = k.Row(p); - KV_t* HWY_RESTRICT k_dest = qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + - head_offset * kFloatsPerTile / 2 + - p % kFloatsPerTile * 2; - KV_t* HWY_RESTRICT v_dest = qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + - head_offset * kFloatsPerTile / 2 + - p % kFloatsPerTile * kFloatsPerTile; + KV_t* HWY_RESTRICT k_dest = + qbatch.KV(0).k_cache.Row(p / kFloatsPerTile) + + qbatch.KV(0).cache->KOffset(0, h / kHeadGroups, kNF, p); + KV_t* HWY_RESTRICT v_dest = + qbatch.KV(0).v_cache.Row(p / kFloatsPerTile) + + qbatch.KV(0).cache->VOffset(0, h / kHeadGroups, kNF, p); TransposeKVCacheRow(k_src, k_dest, v_dest, qkv_dim); } @@ -176,9 +181,6 @@ void TestFlashAttention(size_t target_parallelism, // Copy the output to saved_att to allow for comparison. auto saved_att = MakeCopyOfMat(attention.att_out, ctx.allocator); SetMat(1, attention.q); - using DF = hn::ScalableTag; - const DF df; - const size_t kNF = hn::Lanes(df); const size_t total_tasks = tokens.size() * div_qbatch.GetDivisor() * layer_config.heads; const size_t kVTileSize = GetVTileSize(kNF, kHeadGroups, tokens.size(), diff --git a/gemma/flash_structs.h b/gemma/flash_structs.h index 8c446e0..5473692 100644 --- a/gemma/flash_structs.h +++ b/gemma/flash_structs.h @@ -48,8 +48,8 @@ struct Tile148Params { uint32_t max_last_pos = 0; // Index into the qbatch.KV is the same for each row in the tile. uint32_t qi_index; - // Index into the kv_cache is the same for each row in the tile. - uint32_t kv_offset; + // kv_head is the same for each row in the tile. + uint32_t kv_head; // In the original task, the index to the split tasks of the first split task. uint32_t split_index = 0; // The index of the split for running split attention. diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index f33cd21..d94f917 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -29,11 +29,6 @@ namespace gcpp { -// TODO: rays - Remove this once hwy is updated. -#ifndef HWY_ARCH_MAX_BYTES -#define HWY_ARCH_MAX_BYTES 256 -#endif - // Number of rows for KV cache. Note that both rows and cols are u32, and // the total number of elements can exceed 2^32. static size_t CappedSeqLen(const ModelConfig& config, @@ -46,8 +41,13 @@ static size_t CappedSeqLen(const ModelConfig& config, return inference_args.seq_len; } -KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) - : kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), +KVCache::KVCache(const Extents2D& kv_extents, size_t num_layers, + size_t kv_heads, size_t qkv_dim, const Allocator& allocator) + : num_layers(num_layers), + kv_heads(kv_heads), + qkv_dim(qkv_dim), + rounded_qkv_dim(hwy::RoundUpTo(qkv_dim, kMaxBF16PerVector)), + kv_cache("kv", kv_extents, allocator, MatPadding::kOdd), // WARNING: the rows and cols of k_cache and v_cache will be modified // before use! // The rows will be reduced by a factor of 2xkFloatsPerVector, and the @@ -56,14 +56,12 @@ KVCache::KVCache(const Extents2D& kv_extents, const Allocator& allocator) // machine architecture, since kFloatsPerVector is architecture dependent. // The change is shape is safe only if the padding is kPacked. k_cache("k", - Extents2D(HWY_MAX(kv_extents.rows, - 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), - kv_extents.cols / 2), + Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector), + KOrVDefaultCols()), allocator, MatPadding::kPacked), v_cache("v", - Extents2D(HWY_MAX(kv_extents.rows, - 2 * HWY_ARCH_MAX_BYTES / sizeof(float)), - kv_extents.cols / 2), + Extents2D(hwy::RoundUpTo(kv_extents.rows, kMaxBF16PerVector), + KOrVDefaultCols()), allocator, MatPadding::kPacked), allocator_(allocator) {} @@ -71,7 +69,8 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const Allocator& allocator) : KVCache( Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()), - allocator) {} + config.layer_configs.size(), config.layer_configs[0].kv_heads, + config.layer_configs[0].qkv_dim, allocator) {} KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, const RuntimeConfig& runtime_config, @@ -135,7 +134,7 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, } KVCache KVCache::Copy() { - KVCache copy(kv_cache.Extents(), allocator_); + KVCache copy(kv_cache.Extents(), num_layers, kv_heads, qkv_dim, allocator_); CopyMat(kv_cache, copy.kv_cache); return copy; diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 5fe1f1e..dab636b 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -91,6 +91,38 @@ struct KVCache { return {start_ptr, source_ptr}; } + // Returns the default size of a row in k_cache or v_cache, before scaling by + // 2 * kNF. + size_t KOrVDefaultCols() const { + return num_layers * kv_heads * rounded_qkv_dim; + } + + // Returns an offset into a row of k_cache or v_cache at a position that is + // aligned to the tile size (a multiple of 2kNF). + size_t KOrVOffset(const size_t layer_idx, const size_t kv_head_idx, + const size_t kNF) const { + return (layer_idx * kv_heads + kv_head_idx) * rounded_qkv_dim * 2 * kNF; + } + + // Returns an offset into k_cache at any given position. + size_t KOffset(const size_t layer_idx, const size_t kv_head_idx, + const size_t kNF, const size_t pos) const { + return KOrVOffset(layer_idx, kv_head_idx, kNF) + (pos % (2 * kNF)) * 2; + } + + // Returns an offset into v_cache at any given position. + size_t VOffset(const size_t layer_idx, const size_t kv_head_idx, + const size_t kNF, const size_t pos) const { + return KOrVOffset(layer_idx, kv_head_idx, kNF) + + (pos % (2 * kNF)) * 2 * kNF; + } + + // Saved sizes for computing offsets into the KV cache. + size_t num_layers = 0; + size_t kv_heads = 0; + size_t qkv_dim = 0; + size_t rounded_qkv_dim = 0; + static constexpr size_t kTileSize = 32; std::optional tiled_seq_len = std::nullopt; // Default Format @@ -159,7 +191,8 @@ struct KVCache { const Allocator& allocator_; // For use by other ctor and Copy() - KVCache(const Extents2D& kv_extents, const Allocator& allocator); + KVCache(const Extents2D& kv_extents, size_t num_layers, size_t kv_heads, + size_t qkv_dim, const Allocator& allocator); }; inline size_t KVCachePtr::SeqLen() const {