Move div_grp_size outside

This commit is contained in:
Nanubala Gnana Sai 2024-11-01 19:38:58 +05:30
parent 3b270d236f
commit 719098fd3e
1 changed files with 6 additions and 3 deletions

View File

@ -305,6 +305,9 @@ class GemmaAttention {
}
}
// Self-extension
const hwy::Divisor& div_grp_size{
static_cast<uint32_t>(layer_config_.grp_size)};
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(0, layer_config_.kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@ -319,7 +322,6 @@ class GemmaAttention {
head * layer_config_.qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];
const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) };
const size_t ngb_size = layer_config_.ngb_size;
const bool self_extend = layer_config_.self_extend;
@ -328,7 +330,8 @@ class GemmaAttention {
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
layer_config_.qkv_dim;
// When embedding position, we will use grouped key position
// In self-extend, when embedding position,
// we will use grouped key position
if (self_extend && pos > ngb_size) {
pos = div_grp_size.Divide(pos);
}
@ -1484,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);