From f77e61e514b1cc0e10c8e292d6faaf6a46e57326 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Sat, 19 Oct 2024 13:23:12 +0530 Subject: [PATCH] Use runtime config to set up self-extend --- gemma/configs.h | 4 ++++ gemma/gemma-inl.h | 37 ++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index f7c6ac2..58f3446 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -127,6 +127,10 @@ struct LayerConfig { size_t conv1d_width = 0; bool ff_biases = false; bool softmax_attn_output_biases = false; + bool self_extend = false; + size_t ngb_size = 0; + size_t grp_size = 1; + PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 196f1a4..ea5aca0 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -312,28 +312,29 @@ class GemmaAttention { const size_t interleaved_idx = task / layer_config_.kv_heads; const size_t query_idx = interleaved_idx % num_queries_; const size_t batch_idx = interleaved_idx / num_queries_; - const size_t pos = queries_pos_[query_idx] + batch_idx; + size_t pos = queries_pos_[query_idx] + batch_idx; const size_t cache_pos = div_seq_len_.Remainder(pos); const size_t kv_offset = cache_pos * cache_pos_size_ + layer_ * cache_layer_size_ + head * layer_config_.qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; + + const size_t grp_size = layer_config_.grp_size; + const size_t ngb_size = layer_config_.ngb_size; + const bool self_extend = layer_config_.self_extend; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; const float* HWY_RESTRICT mha_kv = activations_.q.Batch(interleaved_idx) + head * q_stride_ + layer_config_.qkv_dim; + // When embedding position, we will use grouped key position + if (self_extend && pos > ngb_size) { + pos /= grp_size; + } // Copy 
from `q` if MHA, or apply in-place. PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv); - - // When embedding position, we will use grouped key position - if constexpr (TConfig::kSelfExtend) { - if (pos > ngb_size) { - pos /= grp_size; - } - } - // If MHA, also copy V into KVCache. if (is_mha_) { hwy::CopyBytes(mha_kv + layer_config_.qkv_dim, @@ -418,19 +419,21 @@ class GemmaAttention { const size_t batch_idx = interleaved_idx / num_queries_; const size_t head_offset = (head / kHeadGroups) * layer_config_.qkv_dim * 2; + + const size_t grp_size = layer_config_.grp_size; + const size_t ngb_size = layer_config_.ngb_size; + const bool self_extend = layer_config_.self_extend; KVCache& kv_cache = kv_caches_[query_idx]; float* HWY_RESTRICT q = activations_.q.Batch(interleaved_idx) + head * q_stride_; // Apply rope and scaling to Q. - const size_t pos = queries_pos_[query_idx] + batch_idx; - if constexpr (TConfig::kSelfExtend) { - if (pos > ngb_size) { - const size_t grp_pos = pos / grp_size; - const size_t shift = ngb_size - ngb_size / grp_size; - const size_t shifted_grouped_pos = grp_pos + shift; - pos = shifted_grouped_pos; - } + size_t pos = queries_pos_[query_idx] + batch_idx; + if (self_extend && pos > ngb_size) { + const size_t grp_pos = pos / grp_size; + const size_t shift = ngb_size - ngb_size / grp_size; + const size_t shifted_grouped_pos = grp_pos + shift; + pos = shifted_grouped_pos; } PositionalEncodingQK(q, pos, layer_, query_scale, q);