mirror of https://github.com/google/gemma.cpp.git
Fix ASan failure in local attention computation.
PiperOrigin-RevId: 670207380
This commit is contained in:
parent
22d9476aad
commit
f6abbab3a4
|
|
@@ -322,7 +322,7 @@ class GemmaAttention {
|
||||||
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
|
HWY_INLINE void QDotK(const size_t start_pos, const size_t pos,
|
||||||
const size_t head_offset, const float* HWY_RESTRICT q,
|
const size_t head_offset, const float* HWY_RESTRICT q,
|
||||||
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
|
const KVCache& kv_cache, float* HWY_RESTRICT head_att) {
|
||||||
if (HWY_LIKELY(pos <= kSeqLen)) {
|
if (HWY_LIKELY(pos < kSeqLen)) {
|
||||||
// Slightly faster: no wraparound.
|
// Slightly faster: no wraparound.
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
const size_t kv_offset =
|
const size_t kv_offset =
|
||||||
|
|
@@ -355,7 +355,7 @@ class GemmaAttention {
|
||||||
float* HWY_RESTRICT att_out) {
|
float* HWY_RESTRICT att_out) {
|
||||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||||
|
|
||||||
if (HWY_LIKELY(pos <= kSeqLen)) {
|
if (HWY_LIKELY(pos < kSeqLen)) {
|
||||||
// Slightly faster: no wraparound.
|
// Slightly faster: no wraparound.
|
||||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||||
const size_t kv_offset =
|
const size_t kv_offset =
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue