Use runtime config to set up self-extend

Nanubala Gnana Sai 2024-10-19 13:23:12 +05:30
parent fbba1972d0
commit f77e61e514
2 changed files with 24 additions and 17 deletions


@@ -127,6 +127,10 @@ struct LayerConfig {
   size_t conv1d_width = 0;
   bool ff_biases = false;
   bool softmax_attn_output_biases = false;
+  bool self_extend = false;
+  size_t ngb_size = 0;
+  size_t grp_size = 1;
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
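
These new LayerConfig fields replace the compile-time TConfig::kSelfExtend switch with per-layer runtime settings: ngb_size is the neighbor window within which positions stay exact, and grp_size is the grouping factor applied beyond it. A minimal caller-side sketch, assuming the fields are filled in wherever the model's LayerConfig is built (the values are illustrative, not defaults from this commit):

// Hypothetical usage sketch; the field names come from the diff above,
// but the chosen values are illustrative only.
LayerConfig layer_config;
layer_config.self_extend = true;  // enable grouped-position self-extend
layer_config.ngb_size = 4096;     // neighbor window: positions kept exact
layer_config.grp_size = 16;       // grouping factor beyond the window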


@@ -312,28 +312,29 @@ class GemmaAttention {
       const size_t interleaved_idx = task / layer_config_.kv_heads;
       const size_t query_idx = interleaved_idx % num_queries_;
       const size_t batch_idx = interleaved_idx / num_queries_;
-      const size_t pos = queries_pos_[query_idx] + batch_idx;
+      size_t pos = queries_pos_[query_idx] + batch_idx;
       const size_t cache_pos = div_seq_len_.Remainder(pos);
       const size_t kv_offset = cache_pos * cache_pos_size_ +
                                layer_ * cache_layer_size_ +
                                head * layer_config_.qkv_dim * 2;
       KVCache& kv_cache = kv_caches_[query_idx];
+      const size_t grp_size = layer_config_.grp_size;
+      const size_t ngb_size = layer_config_.ngb_size;
+      const bool self_extend = layer_config_.self_extend;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
       const float* HWY_RESTRICT mha_kv =
           activations_.q.Batch(interleaved_idx) + head * q_stride_ +
           layer_config_.qkv_dim;
+      // When embedding position, we will use grouped key position
+      if (self_extend && pos > ngb_size) {
+        pos /= grp_size;
+      }
       // Copy from `q` if MHA, or apply in-place.
       PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
                            kv);
-      // When embedding position, we will use grouped key position
-      if constexpr (TConfig::kSelfExtend) {
-        if (pos > ngb_size) {
-          pos /= grp_size;
-        }
-      }
       // If MHA, also copy V into KVCache.
       if (is_mha_) {
         hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
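
On the key side, self-extend floors any position beyond the neighbor window into groups of grp_size. Note that the adjustment now runs before PositionalEncodingQK; in the removed code it ran after the call (and pos was declared const), so the grouped position never reached the rotary encoding. A minimal standalone sketch of the key-side mapping, using hypothetical names not taken from the commit:

#include <stddef.h>

// Keys inside the neighbor window keep their true position; beyond it,
// the position is floored into groups of grp_size.
size_t SelfExtendKeyPos(size_t pos, size_t ngb_size, size_t grp_size,
                        bool self_extend) {
  if (self_extend && pos > ngb_size) pos /= grp_size;
  return pos;
}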
@@ -418,19 +419,21 @@ class GemmaAttention {
       const size_t batch_idx = interleaved_idx / num_queries_;
       const size_t head_offset =
           (head / kHeadGroups) * layer_config_.qkv_dim * 2;
+      const size_t grp_size = layer_config_.grp_size;
+      const size_t ngb_size = layer_config_.ngb_size;
+      const bool self_extend = layer_config_.self_extend;
       KVCache& kv_cache = kv_caches_[query_idx];
       float* HWY_RESTRICT q =
           activations_.q.Batch(interleaved_idx) + head * q_stride_;
       // Apply rope and scaling to Q.
-      const size_t pos = queries_pos_[query_idx] + batch_idx;
-      if constexpr (TConfig::kSelfExtend) {
-        if (pos > ngb_size) {
-          const size_t grp_pos = pos / grp_size;
-          const size_t shift = ngb_size - ngb_size / grp_size;
-          const size_t shifted_grouped_pos = grp_pos + shift;
-          pos = shifted_grouped_pos;
-        }
-      }
+      size_t pos = queries_pos_[query_idx] + batch_idx;
+      if (self_extend && pos > ngb_size) {
+        const size_t grp_pos = pos / grp_size;
+        const size_t shift = ngb_size - ngb_size / grp_size;
+        const size_t shifted_grouped_pos = grp_pos + shift;
+        pos = shifted_grouped_pos;
+      }
       PositionalEncodingQK(q, pos, layer_, query_scale, q);
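
The query side uses a shifted grouped position rather than the raw one: shift = ngb_size - ngb_size / grp_size makes the mapping continuous at the window boundary (pos = ngb_size and pos = ngb_size + 1 map to the same encoded position) and keeps query-to-key distances within the position range the model saw in training. A worked example under assumed values ngb_size = 512 and grp_size = 4 (not taken from the commit):

#include <stddef.h>
#include <stdio.h>

int main() {
  const size_t ngb_size = 512, grp_size = 4;
  // Key-side mapping (KV hunk above): floor into groups beyond the window.
  auto key_pos = [&](size_t pos) {
    return pos > ngb_size ? pos / grp_size : pos;
  };
  // Query-side mapping (query hunk above): grouped position plus shift.
  auto query_pos = [&](size_t pos) {
    if (pos <= ngb_size) return pos;
    const size_t shift = ngb_size - ngb_size / grp_size;  // 512 - 128 = 384
    return pos / grp_size + shift;
  };
  // Prints "query 1000 -> 634, key 1000 -> 250": the relative offset
  // 634 - 250 = 384 stays below ngb_size even though the true position
  // 1000 is far past the neighbor window.
  printf("query 1000 -> %zu, key 1000 -> %zu\n", query_pos(1000),
         key_pos(1000));
  return 0;
}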