From 4a3a87409b244ae11984f0e1465aa87f5acd0503 Mon Sep 17 00:00:00 2001 From: hongruichen Date: Sat, 12 Jul 2025 11:46:02 +0800 Subject: [PATCH] disable broadcast on flash_attn_ext --- ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp index 1c5ccd9001..9c264654c1 100644 --- a/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp +++ b/ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp @@ -312,6 +312,14 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp return false; } + if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) { + // TODO: add broadcast support + DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n", + op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2], + k->ne[3]); + return false; + } + return true; }