fix unit test failure

This commit is contained in:
hongruichen 2025-07-12 00:39:14 +08:00
parent b720e47606
commit 560729ed6f
1 changed files with 3 additions and 1 deletions

View File

@ -105,7 +105,9 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
}
const npu_device_fp16_t * mp =
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1)) : nullptr;
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
(iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
nullptr;
// k indices
const int ik3 = iq3 / rk3;