From 2a6a95eb778a3dfaebafae2c3eaba6fe8b959c5c Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Fri, 13 Feb 2026 15:36:51 +0800
Subject: [PATCH] Fix MUL_MAT with broadcast; Add unsupported MUL_MAT and FLASH_ATTN cases

---
 ggml/src/ggml-openvino/ggml-openvino.cpp      | 37 +++++++++++++++++--
 .../openvino/op/flash_attn_ext.cpp            |  4 +-
 ggml/src/ggml-openvino/openvino/op/mulmat.cpp | 35 +++++++++---------
 ggml/src/ggml-openvino/utils.cpp              |  6 +--
 4 files changed, 56 insertions(+), 26 deletions(-)

diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
index 6655db7298..780d17b750 100644
--- a/ggml/src/ggml-openvino/ggml-openvino.cpp
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -763,7 +763,7 @@ static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_t
     return ggml_backend_openvino_host_buffer_type(ctx->device);
 }
 
-static bool has_view_input(const ggml_tensor * op) {
+static bool has_view_op_input(const ggml_tensor * op) {
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (op->src[i] == nullptr) {
             break;
@@ -775,6 +775,18 @@
     return false;
 }
 
+static bool is_supported_flash_attn_pattern(const ggml_tensor * op) {
+    // pattern of q,k,v should be q->op==PERMUTE, q->src[0]->op==VIEW, q->src[0]->src[0]->view_src==nullptr
+    for (int i = 0; i < 3; i++) {
+        const ggml_tensor * src = op->src[i];
+        if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW ||
+            src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static bool is_op_unsupported_case(const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_GET_ROWS:
@@ -814,6 +826,9 @@
                 // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
                 return true;
             }
+            if (!is_supported_flash_attn_pattern(op)) {
+                return true;
+            }
             float scale = 1.0f;
             float max_bias = 0.0f;
             float logit_softcap = 0.0f;
@@ -852,6 +867,20 @@
                 // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n");
                 return true;
             }
+            if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1) {
+                return true;
+            }
+            if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) {
+                return true;
+            }
+            if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) {
+                // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)
+                // triggers a bug in ov matmul_shape_inference.hpp
+                return true;
+            }
+            if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) {
+                return true;
+            }
             break;
         }
         case GGML_OP_ROPE: {
@@ -924,7 +953,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
             // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op)));
             return false;
         }
-        if (has_view_input(op)) {
+        if (has_view_op_input(op)) {
             // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
             //               ggml_unary_op_name(ggml_get_unary_op(op)));
             return false;
@@ -937,7 +966,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
            // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op)));
            return false;
        }
-       if (has_view_input(op)) {
+       if (has_view_op_input(op)) {
            // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
            //               ggml_glu_op_name(ggml_get_glu_op(op)));
            return false;
@@ -954,7 +983,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
         GGML_OP_GET_ROWS,
         GGML_OP_RMS_NORM,
     };
-    if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_input(op)) {
+    if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) {
         // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op));
         return false;
     }
diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
index 342da882aa..ca9e99ff88 100644
--- a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
@@ -55,13 +55,13 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
     auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<ov::Node> kv) {
         int64_t factor = num_heads / num_heads_kv;
-        if (factor > 1) {
+        if (factor > 1 && num_heads_kv > 1) {
             ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
             auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
             kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
 
             kv_broadcast_shape = ov::op::v0::Constant::create(
-                ov::element::i64, {5}, {(int64_t) 1, num_heads_kv, factor, (int64_t) 1, head_size});
+                ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
             new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
                                                         {(int64_t) 0, num_heads, (int64_t) -1, head_size});
 
diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp
index 27e4bfa460..d2483e0ab0 100644
--- a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp
+++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp
@@ -47,30 +47,31 @@ OutputVector translate_mulmat(const NodeContext & context) {
     auto B_shape = context.get_input_shape(0).to_shape();
     auto A_shape = context.get_input_shape(1).to_shape();
 
-    int64_t A_batch = A_shape[0];
-    int64_t B_batch = B_shape[0];
+    int64_t A_batch = A_shape[1];
+    int64_t B_batch = B_shape[1];
+    auto A_batch_larger = A_batch > B_batch;
+    auto batch_large = A_batch_larger ? A_batch : B_batch;
+    auto batch_small = A_batch_larger ? B_batch : A_batch;
+    Output<Node> Z = A_batch_larger ? B : A;
 
-    int64_t factor = A_batch_larger ? A_batch / B_batch : B_batch / A_batch;
-    if (factor > 1) {
-        // TODO code is outdated
-        auto A_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{A_batch});
-        auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
+    int64_t factor = batch_large / batch_small;
+    if (factor > 1 && batch_small > 1) {
+        auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_large});
+        auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_small});
         auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
 
-        auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
-
-        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
+        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
         auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
 
-        Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
-        Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
-        auto broadcast_shape =
-            std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
-        auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
+        auto broadcast_shape = ov::op::v0::Constant::create(
+            ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
+        auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
+                                                        {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]});
 
-        auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
-        Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
+        auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape,
+                                                                     ov::op::BroadcastType::BIDIRECTIONAL);
+        Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, true);
     }
     if (A_batch_larger) {
         B = Z;
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
index e79f582939..69fcb0eda4 100644
--- a/ggml/src/ggml-openvino/utils.cpp
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -54,13 +54,13 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) {
         return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful);
     } catch (const ov::Exception & e) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what());
+        GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what());
         return GGML_STATUS_FAILED;
     } catch (const std::exception & e) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what());
+        GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what());
         return GGML_STATUS_FAILED;
     } catch (...) {
-        // GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n");
+        GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n");
         return GGML_STATUS_FAILED;
     }
 }
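Note (not part of the patch): a minimal standalone sketch of the MUL_MAT batch-broadcast guard added in is_op_unsupported_case() above. It builds a MUL_MAT node whose ne[3] dims broadcast (2 vs 4, which ggml itself accepts since 4 % 2 == 0) and applies the same ne[3] check, showing the kind of case the OpenVINO backend now declines. It assumes only the public ggml C API (ggml_init, ggml_new_tensor_4d, ggml_mul_mat); the helper name mul_mat_ne3_unsupported is illustrative, not code from the backend.

#include <stdbool.h>
#include <stdio.h>
#include "ggml.h"

// Mirrors the new check: reject when the ne[3] dims differ and neither is 1.
static bool mul_mat_ne3_unsupported(const struct ggml_tensor * op) {
    return op->src[0]->ne[3] != op->src[1]->ne[3] &&
           op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1;
}

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,  // only graph metadata is needed for this check
    };
    struct ggml_context * ctx = ggml_init(params);

    // k = 64; batch dims ne[2]/ne[3]. ne[3] = 2 vs 4 is legal for ggml_mul_mat
    // (4 % 2 == 0) but hits the broadcast case the backend now reports as unsupported.
    struct ggml_tensor * a   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 32, 4, 2);
    struct ggml_tensor * b   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 16, 4, 4);
    struct ggml_tensor * out = ggml_mul_mat(ctx, a, b);

    printf("unsupported by OpenVINO backend: %s\n",
           mul_mat_ne3_unsupported(out) ? "yes" : "no");  // prints "yes"

    ggml_free(ctx);
    return 0;
}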