disable broadcast on flash_attn_ext
This commit is contained in:
parent
560729ed6f
commit
4a3a87409b
|
|
@ -312,6 +312,14 @@ bool is_flash_attn_supported(npu_device_tensor_op op, const npu_device_tensor_sp
|
|||
return false;
|
||||
}
|
||||
|
||||
if (q->ne[2] != k->ne[2] || q->ne[3] != k->ne[3] || q->ne[3] != 1) {
|
||||
// TODO: add broadcast support
|
||||
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
|
||||
op_get_name(op), q->ne[0], q->ne[1], q->ne[2], q->ne[3], k->ne[0], k->ne[1], k->ne[2],
|
||||
k->ne[3]);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue