Fix MUL_MAT with broadcast; Add unsupported MUL_MAT FLASH_ATTN cases

Yu, Zijun 2026-02-13 15:36:51 +08:00
parent 1a54965c43
commit 2a6a95eb77
4 changed files with 56 additions and 26 deletions

View File

@@ -763,7 +763,7 @@ static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_t
     return ggml_backend_openvino_host_buffer_type(ctx->device);
 }
-static bool has_view_input(const ggml_tensor * op) {
+static bool has_view_op_input(const ggml_tensor * op) {
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (op->src[i] == nullptr) {
             break;
@@ -775,6 +775,18 @@ static bool has_view_input(const ggml_tensor * op) {
     return false;
 }
+static bool is_supported_flash_attn_pattern(const ggml_tensor * op) {
+    // pattern of q,k,v should be q->op==PERMUTE, q->src[0]->op==VIEW, q->src[0]->src[0]->view_src==nullptr
+    for (int i = 0; i < 3; i++) {
+        const ggml_tensor * src = op->src[i];
+        if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW ||
+            src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) {
+            return false;
+        }
+    }
+    return true;
+}
 static bool is_op_unsupported_case(const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_GET_ROWS:
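
Note: the new is_supported_flash_attn_pattern requires each of q, k, v (op->src[0..2]) to be a PERMUTE of a VIEW of a plain, non-view tensor. A minimal sketch of a source chain that passes the check, with illustrative sizes and names, assuming the standard ggml API:

    // hypothetical chain: PERMUTE -> VIEW -> base tensor (base->view_src == nullptr)
    ggml_tensor * base = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);
    ggml_tensor * view = ggml_view_3d(ctx, base, 128, 32, 32,
                                      128 * ggml_type_size(GGML_TYPE_F32),    // nb1: 128 floats per row
                                      4096 * ggml_type_size(GGML_TYPE_F32),   // nb2: one [128, 32] slice
                                      0);                                     // offset
    ggml_tensor * q = ggml_permute(ctx, view, 0, 2, 1, 3);
    // q->op == GGML_OP_PERMUTE, q->src[0]->op == GGML_OP_VIEW, and
    // q->src[0]->src[0]->view_src == nullptr, so the pattern check passes for q.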
@@ -814,6 +826,9 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
                 // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
                 return true;
             }
+            if (!is_supported_flash_attn_pattern(op)) {
+                return true;
+            }
             float scale = 1.0f;
             float max_bias = 0.0f;
             float logit_softcap = 0.0f;
@@ -852,6 +867,20 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
                 // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n");
                 return true;
             }
+            if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1) {
+                return true;
+            }
+            if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) {
+                return true;
+            }
+            if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) {
+                // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)
+                // triggers a bug in ov matmul_shape_inference.hpp
+                return true;
+            }
+            if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) {
+                return true;
+            }
             break;
         }
         case GGML_OP_ROPE: {
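
Note: the first new MUL_MAT check rejects nodes whose outermost batch dims (ne[3]) differ while neither is 1, i.e. shapes that cannot be lined up by broadcasting. A sketch of the condition it guards, as a hypothetical helper (not part of the source):

    // broadcastable when equal or when one side is 1
    static bool ne3_broadcastable(int64_t ne3_src0, int64_t ne3_src1) {
        return ne3_src0 == ne3_src1 || ne3_src0 == 1 || ne3_src1 == 1;
    }
    // ne3_broadcastable(4, 4) -> true     ne3_broadcastable(1, 4) -> true
    // ne3_broadcastable(2, 4) -> false    (the case marked unsupported above)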
@@ -924,7 +953,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
             // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op)));
             return false;
         }
-        if (has_view_input(op)) {
+        if (has_view_op_input(op)) {
             // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
             //               ggml_unary_op_name(ggml_get_unary_op(op)));
             return false;
@@ -937,7 +966,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
             // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op)));
             return false;
         }
-        if (has_view_input(op)) {
+        if (has_view_op_input(op)) {
             // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
             //               ggml_glu_op_name(ggml_get_glu_op(op)));
             return false;
@@ -954,7 +983,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
         GGML_OP_GET_ROWS,
         GGML_OP_RMS_NORM,
     };
-    if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_input(op)) {
+    if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) {
        // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op));
        return false;
    }
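
Note: the body of has_view_op_input is truncated in this view; judging from the visible loop and the call sites, it presumably returns true when any source tensor is produced by a VIEW op. A sketch of the likely shape (an assumption, not the verbatim source):

    static bool has_view_op_input(const ggml_tensor * op) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (op->src[i] == nullptr) {
                break;
            }
            if (op->src[i]->op == GGML_OP_VIEW) {  // source produced by a VIEW op
                return true;
            }
        }
        return false;
    }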

View File

@@ -55,13 +55,13 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
     auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) {
         int64_t factor = num_heads / num_heads_kv;
-        if (factor > 1) {
+        if (factor > 1 && num_heads_kv > 1) {
             ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
             auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
             kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
             kv_broadcast_shape = ov::op::v0::Constant::create(
-                ov::element::i64, {5}, {(int64_t) 1, num_heads_kv, factor, (int64_t) 1, head_size});
+                ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
             new_kv_shape =
                 ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size});
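
Note: tile_kv repeats each kv head factor = num_heads / num_heads_kv times so grouped-query attention lines up with the query heads. The constant no longer bakes num_heads_kv and head_size into the broadcast shape; {1, 1, factor, 1, 1} tiles axis 2 only, presumably relying on numpy-style broadcast semantics, and when num_heads_kv == 1 the tiling is skipped since MatMul's implicit broadcasting already covers that case. A shape walk with illustrative sizes (num_heads = 32, num_heads_kv = 8, head_size = 64):

    // kv:                        [1, 8, seq, 64]
    // Unsqueeze axis 2:          [1, 8, 1, seq, 64]
    // Broadcast {1, 1, 4, 1, 1}: [1, 8, 4, seq, 64]   (factor = 32 / 8 = 4)
    // Reshape {0, 32, -1, 64}:   [1, 32, seq, 64]     (0 keeps dim 0, -1 infers seq)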

View File

@@ -47,30 +47,31 @@ OutputVector translate_mulmat(const NodeContext & context) {
     auto B_shape = context.get_input_shape(0).to_shape();
     auto A_shape = context.get_input_shape(1).to_shape();
-    int64_t A_batch = A_shape[0];
-    int64_t B_batch = B_shape[0];
+    int64_t A_batch = A_shape[1];
+    int64_t B_batch = B_shape[1];
     auto A_batch_larger = A_batch > B_batch;
+    auto batch_large = A_batch_larger ? A_batch : B_batch;
+    auto batch_small = A_batch_larger ? B_batch : A_batch;
     Output<Node> Z = A_batch_larger ? B : A;
-    int64_t factor = A_batch_larger ? A_batch / B_batch : B_batch / A_batch;
-    if (factor > 1) {
-        // TODO code is outdated
-        auto A_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{A_batch});
-        auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
+    int64_t factor = batch_large / batch_small;
+    if (factor > 1 && batch_small > 1) {
+        auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_large});
+        auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_small});
-        auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
-        auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
-        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
+        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
         auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
-        Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
-        Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
-        auto broadcast_shape =
-            std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
-        auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
+        auto broadcast_shape = ov::op::v0::Constant::create(
+            ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
+        auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
+                                                        {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]});
-        auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
-        Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
+        auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape,
+                                                                     ov::op::BroadcastType::BIDIRECTIONAL);
+        Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, true);
     }
     if (A_batch_larger) {
         B = Z;
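
Note: the fix reads the batch size from index 1 of the 4-D OV shapes instead of index 0, builds the broadcast and reshape shapes as constants rather than Concat nodes, and, mirroring the tile_kv change above, skips explicit tiling when batch_small == 1. A worked shape example with illustrative sizes (A_batch = 32, B_batch = 8, so Z = B and factor = 4):

    // Z = B:                                         [1, 8, n, k]
    // Unsqueeze axis 2:                              [1, 8, 1, n, k]
    // Broadcast {1, 1, 4, 1, 1} (BIDIRECTIONAL):     [1, 8, 4, n, k]
    // Reshape {0, 32, -1, A_shape[3]}, special_zero: [1, 32, n, k]   (8 * 4 merge into 32, -1 infers n)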

View File

@@ -54,13 +54,13 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) {
         return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful);
     } catch (const ov::Exception & e) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what());
+        GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what());
         return GGML_STATUS_FAILED;
     } catch (const std::exception & e) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what());
+        GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what());
         return GGML_STATUS_FAILED;
     } catch (...) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n");
+        GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n");
         return GGML_STATUS_FAILED;
     }
 }