mirror of https://github.com/google/gemma.cpp.git
Use runtime config to setup self extend
This commit is contained in:
parent
fbba1972d0
commit
f77e61e514
|
|
@ -127,6 +127,10 @@ struct LayerConfig {
|
||||||
size_t conv1d_width = 0;
|
size_t conv1d_width = 0;
|
||||||
bool ff_biases = false;
|
bool ff_biases = false;
|
||||||
bool softmax_attn_output_biases = false;
|
bool softmax_attn_output_biases = false;
|
||||||
|
bool self_extend = false;
|
||||||
|
size_t ngb_size = 0;
|
||||||
|
size_t grp_size = 1;
|
||||||
|
|
||||||
PostNormType post_norm = PostNormType::None;
|
PostNormType post_norm = PostNormType::None;
|
||||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||||
ActivationType activation = ActivationType::Gelu;
|
ActivationType activation = ActivationType::Gelu;
|
||||||
|
|
|
||||||
|
|
@ -312,28 +312,29 @@ class GemmaAttention {
|
||||||
const size_t interleaved_idx = task / layer_config_.kv_heads;
|
const size_t interleaved_idx = task / layer_config_.kv_heads;
|
||||||
const size_t query_idx = interleaved_idx % num_queries_;
|
const size_t query_idx = interleaved_idx % num_queries_;
|
||||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||||
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
const size_t cache_pos = div_seq_len_.Remainder(pos);
|
||||||
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
const size_t kv_offset = cache_pos * cache_pos_size_ +
|
||||||
layer_ * cache_layer_size_ +
|
layer_ * cache_layer_size_ +
|
||||||
head * layer_config_.qkv_dim * 2;
|
head * layer_config_.qkv_dim * 2;
|
||||||
KVCache& kv_cache = kv_caches_[query_idx];
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
|
|
||||||
|
const size_t grp_size = layer_config_.grp_size;
|
||||||
|
const size_t ngb_size = layer_config_.ngb_size;
|
||||||
|
const bool self_extend = layer_config_.self_extend;
|
||||||
|
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||||
const float* HWY_RESTRICT mha_kv =
|
const float* HWY_RESTRICT mha_kv =
|
||||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||||
layer_config_.qkv_dim;
|
layer_config_.qkv_dim;
|
||||||
|
|
||||||
|
// When embedding position, we will use grouped key position
|
||||||
|
if (self_extend && pos > ngb_size) {
|
||||||
|
pos /= grp_size;
|
||||||
|
}
|
||||||
// Copy from `q` if MHA, or apply in-place.
|
// Copy from `q` if MHA, or apply in-place.
|
||||||
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
||||||
kv);
|
kv);
|
||||||
|
|
||||||
// When embedding position, we will use grouped key position
|
|
||||||
if constexpr (TConfig::kSelfExtend) {
|
|
||||||
if (pos > ngb_size) {
|
|
||||||
pos /= grp_size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If MHA, also copy V into KVCache.
|
// If MHA, also copy V into KVCache.
|
||||||
if (is_mha_) {
|
if (is_mha_) {
|
||||||
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
|
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
|
||||||
|
|
@ -418,20 +419,22 @@ class GemmaAttention {
|
||||||
const size_t batch_idx = interleaved_idx / num_queries_;
|
const size_t batch_idx = interleaved_idx / num_queries_;
|
||||||
const size_t head_offset =
|
const size_t head_offset =
|
||||||
(head / kHeadGroups) * layer_config_.qkv_dim * 2;
|
(head / kHeadGroups) * layer_config_.qkv_dim * 2;
|
||||||
|
|
||||||
|
const size_t grp_size = layer_config_.grp_size;
|
||||||
|
const size_t ngb_size = layer_config_.ngb_size;
|
||||||
|
const bool self_extend = layer_config_.self_extend;
|
||||||
KVCache& kv_cache = kv_caches_[query_idx];
|
KVCache& kv_cache = kv_caches_[query_idx];
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations_.q.Batch(interleaved_idx) + head * q_stride_;
|
activations_.q.Batch(interleaved_idx) + head * q_stride_;
|
||||||
|
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||||
if constexpr (TConfig::kSelfExtend) {
|
if (self_extend && pos > ngb_size) {
|
||||||
if (pos > ngb_size) {
|
|
||||||
const size_t grp_pos = pos / grp_size;
|
const size_t grp_pos = pos / grp_size;
|
||||||
const size_t shift = ngb_size - ngb_size / grp_size;
|
const size_t shift = ngb_size - ngb_size / grp_size;
|
||||||
const size_t shifted_grouped_pos = grp_pos + shift;
|
const size_t shifted_grouped_pos = grp_pos + shift;
|
||||||
pos = shifted_grouped_pos;
|
pos = shifted_grouped_pos;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
PositionalEncodingQK(q, pos, layer_, query_scale, q);
|
PositionalEncodingQK(q, pos, layer_, query_scale, q);
|
||||||
|
|
||||||
const size_t start_pos = StartPos(pos, layer_);
|
const size_t start_pos = StartPos(pos, layer_);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue