mirror of https://github.com/google/gemma.cpp.git
Move div_grp_size outside
parent 3b270d236f
commit 719098fd3e
@@ -305,6 +305,9 @@ class GemmaAttention {
       }
     }
 
+    // Self-extension
+    const hwy::Divisor& div_grp_size{
+        static_cast<uint32_t>(layer_config_.grp_size)};
     // Apply positional encodings for K (and copy KV to cache if MHA).
     pool_.Run(0, layer_config_.kv_heads * num_interleaved,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
@@ -319,7 +322,6 @@ class GemmaAttention {
                       head * layer_config_.qkv_dim * 2;
               KVCache& kv_cache = kv_caches_[query_idx];
 
-              const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) };
               const size_t ngb_size = layer_config_.ngb_size;
               const bool self_extend = layer_config_.self_extend;
 
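Why the move in the two hunks above matters: the removed line constructed the hwy::Divisor inside the lambda that pool_.Run executes once per kv_head x interleaved-token task, even though layer_config_.grp_size is fixed for the layer. hwy::Divisor precomputes a multiplier/shift at construction so that each Divide() call avoids a hardware division; hoisting it outside the lambda pays that setup cost once and lets every task reuse it. (The const reference binding in the added lines extends the lifetime of the temporary, so it behaves like a plain const value for the rest of the function.) Below is a minimal standalone sketch of the pattern, assuming hwy::Divisor and its Divide() method from hwy/base.h as used in this diff; GroupPositions and the test data are illustrative, not gemma.cpp code.

#include <cstddef>
#include <cstdint>
#include <cstdio>

#include "hwy/base.h"  // hwy::Divisor (assumed header location)

// Illustrative: map many positions to grouped positions with one shared
// Divisor. Constructing hwy::Divisor precomputes the magic multiplier/shift,
// so build it once outside the loop (or lambda) instead of per iteration.
void GroupPositions(const uint32_t* pos, uint32_t* grouped, size_t num,
                    uint32_t grp_size) {
  const hwy::Divisor div_grp_size(grp_size);  // built once, not per task
  for (size_t i = 0; i < num; ++i) {
    grouped[i] = div_grp_size.Divide(pos[i]);  // same result as pos[i] / grp_size
  }
}

int main() {
  const uint32_t pos[4] = {0, 7, 8, 31};
  uint32_t grouped[4];
  GroupPositions(pos, grouped, 4, /*grp_size=*/8);
  for (uint32_t g : grouped) printf("%u ", g);  // prints: 0 0 1 3
  printf("\n");
  return 0;
}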
@@ -328,7 +330,8 @@ class GemmaAttention {
                   activations_.q.Batch(interleaved_idx) + head * q_stride_ +
                   layer_config_.qkv_dim;
 
-              // When embedding position, we will use grouped key position
+              // In self-extend, when embedding position,
+              // we will use grouped key position
               if (self_extend && pos > ngb_size) {
                 pos = div_grp_size.Divide(pos);
               }
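For context on the branch this hunk touches: with self-extend enabled, key positions inside the neighbour window (ngb_size) are embedded with their exact value, while positions beyond it are coarsened by integer division with grp_size, so long contexts reuse position values seen during training. The sketch below reproduces only the mapping visible in this hunk (the self_extend flag check is omitted); SelfExtendPos and the test values are hypothetical names, ngb_size/grp_size mirror layer_config_.ngb_size and layer_config_.grp_size, and plain '/' stands in for div_grp_size.Divide.

#include <cstdint>
#include <cstdio>

// Sketch of the position remapping in the hunk above: exact positions inside
// the neighbour window, grouped (divided) positions beyond it.
uint32_t SelfExtendPos(uint32_t pos, uint32_t ngb_size, uint32_t grp_size) {
  if (pos > ngb_size) {
    return pos / grp_size;  // grouped key position for distant tokens
  }
  return pos;  // unmodified position for nearby tokens
}

int main() {
  const uint32_t ngb_size = 4, grp_size = 4;
  const uint32_t tests[] = {2, 4, 5, 20};
  for (uint32_t p : tests) {
    printf("pos %u -> %u\n", p, SelfExtendPos(p, ngb_size, grp_size));
  }
  // pos 2 -> 2, pos 4 -> 4, pos 5 -> 1, pos 20 -> 5
  return 0;
}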
@@ -1484,7 +1487,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
                                          qbatch_size);
     QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                      qbatch_size);
+                                       qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
                  qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);