Use runtime config to setup self extend

This commit is contained in:
Nanubala Gnana Sai 2024-10-19 13:23:12 +05:30
parent fbba1972d0
commit f77e61e514
2 changed files with 24 additions and 17 deletions

View File

@ -127,6 +127,10 @@ struct LayerConfig {
size_t conv1d_width = 0;
bool ff_biases = false;
bool softmax_attn_output_biases = false;
bool self_extend = false;
size_t ngb_size = 0;
size_t grp_size = 1;
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;

View File

@ -312,28 +312,29 @@ class GemmaAttention {
const size_t interleaved_idx = task / layer_config_.kv_heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t pos = queries_pos_[query_idx] + batch_idx;
size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ +
head * layer_config_.qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
const size_t grp_size = layer_config_.grp_size;
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
layer_config_.qkv_dim;
// When embedding position, we will use grouped key position
if (self_extend && pos > ngb_size) {
pos /= grp_size;
}
// Copy from `q` if MHA, or apply in-place.
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
kv);
// When embedding position, we will use grouped key position
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
pos /= grp_size;
}
}
// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@ -418,19 +419,21 @@ class GemmaAttention {
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset =
(head / kHeadGroups) * layer_config_.qkv_dim * 2;
const size_t grp_size = layer_config_.grp_size;
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;
KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;
// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
size_t pos = queries_pos_[query_idx] + batch_idx;
if (self_extend && pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);