diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ea5aca0..9fd739a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -319,7 +319,7 @@ class GemmaAttention { head * layer_config_.qkv_dim * 2; KVCache& kv_cache = kv_caches_[query_idx]; - const size_t grp_size = layer_config_.grp_size; + const hwy::Divisor& div_grp_size { static_cast<uint32_t>(layer_config_.grp_size) }; const size_t ngb_size = layer_config_.ngb_size; const bool self_extend = layer_config_.self_extend; @@ -330,7 +330,7 @@ class GemmaAttention { // When embedding position, we will use grouped key position if (self_extend && pos > ngb_size) { - pos /= grp_size; + pos = div_grp_size.Divide(pos); } // Copy from `q` if MHA, or apply in-place. PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,