disable broadcast on flash_attn_ext

This commit is contained in:
hongruichen 2025-07-12 11:46:02 +08:00
parent 560729ed6f
commit 4a3a87409b
1 changed file with 8 additions and 0 deletions

View File

@@ -312,6 +312,14 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
return false;
}
if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
// TODO: add broadcast support
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
k->ne[3]);
return false;
}
return true;
}