diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 6dc122b..cee376f 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -322,7 +322,7 @@ class GemmaAttention { HWY_INLINE void QDotK(const size_t start_pos, const size_t pos, const size_t head_offset, const float* HWY_RESTRICT q, const KVCache& kv_cache, float* HWY_RESTRICT head_att) { - if (HWY_LIKELY(pos <= kSeqLen)) { + if (HWY_LIKELY(pos < kSeqLen)) { // Slightly faster: no wraparound. for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { const size_t kv_offset = @@ -355,7 +355,7 @@ class GemmaAttention { float* HWY_RESTRICT att_out) { hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - if (HWY_LIKELY(pos <= kSeqLen)) { + if (HWY_LIKELY(pos < kSeqLen)) { // Slightly faster: no wraparound. for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { const size_t kv_offset =