Fix ASan failure in local attention computation.

PiperOrigin-RevId: 670207380
This commit is contained in:
Zoltan Szabadka 2024-09-02 07:05:46 -07:00 committed by Copybara-Service
parent 22d9476aad
commit f6abbab3a4
1 changed file with 2 additions and 2 deletions

View File

@@ -322,7 +322,7 @@ class GemmaAttention {
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
const size_t head_offset, const float* HWY_RESTRICT q,
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
- if (HWY_LIKELY(pos <= kSeqLen)) {
+ if (HWY_LIKELY(pos < kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t kv_offset =
@@ -355,7 +355,7 @@ class GemmaAttention {
float* HWY_RESTRICT att_out) {
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
- if (HWY_LIKELY(pos <= kSeqLen)) {
+ if (HWY_LIKELY(pos < kSeqLen)) {
// Slightly faster: no wraparound.
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t kv_offset =