diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
index af0a122a7e..1c5ccd9001 100644
--- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
+++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
@@ -105,7 +105,9 @@ void flash_attn_impl(hexagon::tensor * out, const hexagon::tensor * q, const hex
         }

         const npu_device_fp16_t * mp =
-            mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1)) : nullptr;
+            mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
+                                                                   (iq3 % mask->get_ne(2)) * mask->get_nb(2)) :
+                       nullptr;

         // k indices
         const int ik3 = iq3 / rk3;