fix unit test failure
This commit is contained in:
parent
b720e47606
commit
560729ed6f
|
|
@ -105,7 +105,9 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
|
|||
}
|
||||
|
||||
const npu_device_fp16_t * mp =
|
||||
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1)) : nullptr;
|
||||
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
|
||||
(iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
|
||||
nullptr;
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
|
|
|
|||
Loading…
Reference in New Issue