From 6c517f151d6b3b4729199b62aad21781ba23807c Mon Sep 17 00:00:00 2001
From: hongruichen
Date: Sat, 9 Aug 2025 10:47:06 +0800
Subject: [PATCH] feat: add sinks tensor support in fa impl

---
 .../src/ggml-qnn/npu/device/op_flash_attn.cpp | 43 +++++++++++++------
 ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl     |  2 +-
 2 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
index b721d06f37..776bcb74f3 100644
--- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
+++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
@@ -19,6 +19,7 @@ void flash_attn_impl(hexagon::tensor * out,
                      const hexagon::tensor * k,
                      const hexagon::tensor * v,
                      const hexagon::tensor * mask,
+                     const hexagon::tensor * sinks,
                      hexagon::compute_params * params) {
     static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");
 
@@ -92,11 +93,12 @@ void flash_attn_impl(hexagon::tensor * out,
     }
 
     DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_WITH_MULTI_SUB_PROC(out, params->get_thread_index(), flash_attn);
-    const uint8_t * q_ptr    = q->get_read_buffer();
-    const uint8_t * k_ptr    = k->get_read_buffer();
-    const uint8_t * v_ptr    = v->get_read_buffer();
-    const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
-    float *         VKQ32    = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
+    const uint8_t * q_ptr     = q->get_read_buffer();
+    const uint8_t * k_ptr     = k->get_read_buffer();
+    const uint8_t * v_ptr     = v->get_read_buffer();
+    const uint8_t * mask_ptr  = mask ? mask->get_read_buffer() : nullptr;
+    const uint8_t * sinks_ptr = sinks ? sinks->get_read_buffer() : nullptr;
+    float *         VKQ32     = reinterpret_cast<float *>(cache_ptr);  // FP32 VKQ accumulator
     auto * VKQ16 = reinterpret_cast(VKQ32 + aligned_dv);  // (temporary) FP16 VKQ accumulator
     auto * Q_q   = reinterpret_cast(
         VKQ32 + 2 * aligned_dv);  // (temporary) buffer for Q converted to quantized/FP16
@@ -224,6 +226,22 @@ void flash_attn_impl(hexagon::tensor * out,
             }
         }
 
+        if (sinks_ptr) {
+            const float s = reinterpret_cast<const float *>(sinks_ptr)[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                hexagon::vec_scale_f32(VKQ32, ms, VKQ32, DV);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S * ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f / S;
         hexagon::vec_scale_f32(VKQ32, S_inv, VKQ32, DV);
@@ -253,20 +271,21 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
         return false;
     }
 
-    const auto * q    = out->get_src(0);
-    const auto * k    = out->get_src(1);
-    const auto * v    = out->get_src(2);
-    const auto * mask = out->get_src(3);
-    if (!q || !k || !v || !mask) {
+    const auto * q = out->get_src(0);
+    const auto * k = out->get_src(1);
+    const auto * v = out->get_src(2);
+    if (!q || !k || !v) {
         DEVICE_LOG_DEBUG(
             "invalid src tensors: q: %p, k: %p, v: %p, mask: %p\n", (void *) q, (void *) k, (void *) v, (void *) mask);
         return false;
     }
 
+    const auto * mask  = out->get_src(3);
+    const auto * sinks = out->get_src(4);
     if (k->get_type() == NPU_DATA_TYPE_F16) {
-        flash_attn_impl(out, q, k, v, mask, params);
+        flash_attn_impl(out, q, k, v, mask, sinks, params);
     } else {
-        flash_attn_impl(out, q, k, v, mask, params);
+        flash_attn_impl(out, q, k, v, mask, sinks, params);
     }
     return true;
 }
diff --git a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
index 1a9a4cb3a6..5aab3524c6 100644
--- a/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
+++ b/ggml/src/ggml-qnn/npu/idl/hexagon_npu.idl
@@ -3,7 +3,7 @@
 #include "remote.idl"
 
 const uint32_t DEVICE_TENSOR_MAX_DIMS = 4;
-const uint32_t DEVICE_TENSOR_MAX_SRC = 4;
+const uint32_t DEVICE_TENSOR_MAX_SRC = 5;
const uint32_t DEVICE_TENSOR_MAX_OP_PARAMS = 16;
 const uint32_t QUANT_BLOCK_SIZE = 32;
 const uint32_t QUANT_K_BLOCK_SIZE = 256;
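
Note (not part of the patch): the sinks block added above folds one extra per-head logit into the streaming softmax state the kernel keeps per output row, i.e. the running maximum M, the running denominator S, and the value accumulator VKQ32. If the sink logit s exceeds M, the accumulated state is rescaled by ms = expf(M - s) (the kernel also rescales VKQ32); otherwise the sink just adds vs = expf(s - M) to the denominator. The standalone sketch below reproduces that arithmetic so the correction can be checked in isolation; the logit and sink values and the file name are made up for illustration and do not come from the repository.

// sink_softmax_sketch.cpp -- illustrative only, not part of the patch.
// Folds a single "sink" logit into a numerically stable streaming softmax,
// mirroring the ms/vs correction in flash_attn_impl above.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = {0.5f, 2.0f, -1.0f};  // hypothetical attention logits for one row
    const float sink = 3.0f;                          // hypothetical per-head sink logit

    // Running maximum M and running sum S of expf(logit - M), as maintained by the kernel.
    float M = -INFINITY;
    float S = 0.0f;
    for (float x : logits) {
        if (x > M) {
            S *= expf(M - x);  // rescale the existing sum to the new maximum
            M = x;
        }
        S += expf(x - M);
    }

    // Fold the sink into the denominator. If it exceeds the running maximum,
    // everything accumulated so far must be rescaled (the kernel rescales VKQ32
    // at this point); otherwise the sink just adds expf(sink - M).
    float ms = 1.0f;
    float vs = 1.0f;
    if (sink > M) {
        ms = expf(M - sink);
    } else {
        vs = expf(sink - M);
    }
    S = S * ms + vs;

    // Final probabilities use the effective maximum max(M, sink); they sum to
    // less than 1, the remainder being the mass absorbed by the sink.
    const float M_eff = (sink > M) ? sink : M;
    float kept = 0.0f;
    for (float x : logits) {
        kept += expf(x - M_eff) / S;
    }
    printf("mass kept = %f, mass absorbed by sink = %f\n", kept, 1.0f - kept);
    return 0;
}

Because the sink contributes only to the denominator and carries no value vector, the net effect is that the output row is scaled down by the probability mass the sink absorbs, which is exactly what the final division of VKQ32 by the updated S does in the kernel.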