From 1a54965c439f5b1e7f71d4e3b0232f61d8d36139 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Fri, 13 Feb 2026 10:29:48 +0800 Subject: [PATCH] Suppress logging and add error handling to allow test-backend-ops to complete --- ggml/src/ggml-openvino/ggml-decoder.cpp | 6 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 2 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 94 +++++++++---------- ggml/src/ggml-openvino/openvino/op/rope.cpp | 4 +- ggml/src/ggml-openvino/utils.cpp | 37 +++++--- 5 files changed, 76 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 8796c23abd..99776e1bec 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -568,20 +568,20 @@ std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor // F16/F32/BF16 weight with shared-memory constant auto * weight_extra = static_cast(tensor->extra); if (weight_extra->weight_node) { - GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name); + // GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name); return weight_extra->weight_node; } } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) { // Quantized weight with pre-extracted data auto * quant_extra = static_cast(tensor->extra); if (quant_extra->weight_node) { - GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name); + // GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name); return quant_extra->weight_node; } } } - GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name); + // GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name); static const std::set weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index 39bf7610eb..7a48ed1b65 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -348,7 +348,7 @@ ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor element_type = ov::element::i64; break; default: - GGML_LOG_ERROR("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type)); + // GGML_LOG_WARN("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type)); return nullptr; } diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 801d9ad5c4..6655db7298 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -223,7 +223,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; - // Check if this is a weight buffer (usage is set BEFORE set_tensor is called) + // Check if this is a weight buffer (usage is set BEFORE set_tensor is called, except in test-backend-ops) bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS); // Full tensor set: offset=0, full size, not a view bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr); @@ -235,7 +235,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer auto result = process_weight_tensor(tensor, data, tensor->data); result.weight_node->set_friendly_name(tensor->name); - const auto & layout = result.layout; + // const auto & layout = result.layout; ggml_openvino_extra_base * extra; // Quantized path with extracted weight/scale/zp tensors @@ -243,24 +243,24 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales), std::move(result.zp), result.weight_node); - if (layout.is_requant) { - GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, - extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8, - layout.weights_per_block); - } else { - int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; - GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n", - __func__, tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); - } + // if (layout.is_requant) { + // GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, + // extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8, + // layout.weights_per_block); + // } else { + // int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + // GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n", + // __func__, tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); + // } } else { // F16/F32/BF16 weight or F16-requant extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node); - if (layout.total_size > 0) { - GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); - } else { - GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name); - } + // if (layout.total_size > 0) { + // GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); + // } else { + // GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name); + // } } ctx->tensor_extras[tensor] = extra; @@ -271,7 +271,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer memcpy((char *) tensor->data + offset, data, size); } } else { - // Non-weight tensor (KV cache, activations, etc.) - copy data + // Non-weight tensor (KV cache, activations, etc.) - copy data. test-backend-ops also goes here if (ctx->is_remote) { cl_command_queue queue = ggml_openvino_get_cl_queue(); auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); @@ -290,7 +290,7 @@ static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); if (extra == nullptr) { - GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); + // GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); return; } @@ -795,7 +795,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { } case GGML_OP_SOFT_MAX: { if (op->src[2] != nullptr) { - GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n"); + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n"); return true; } float scale = 1.0f; @@ -804,14 +804,14 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { memcpy(&scale, (const float *) op_params + 0, sizeof(float)); memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); if (max_bias > 0) { - GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); return true; } break; } case GGML_OP_FLASH_ATTN_EXT: { if (op->src[4] != nullptr) { - GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); return true; } float scale = 1.0f; @@ -822,11 +822,11 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float)); if (max_bias > 0) { - GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); return true; } if (logit_softcap != 0) { - GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); return true; } break; @@ -834,14 +834,14 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { case GGML_OP_PERMUTE: { if (op->type == GGML_TYPE_BF16) { // err msg: [GPU] Could not find a suitable kernel for transpose - GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n"); + // GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n"); return true; } break; } case GGML_OP_CPY: { if (op->src[1] != op) { - GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n"); + // GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n"); return true; } break; @@ -849,7 +849,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { case GGML_OP_MUL_MAT: { if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) { // Has accuracy issue, try enabling this and see `test-backend-ops -o "MUL_MAT"` - GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n"); + // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n"); return true; } break; @@ -858,17 +858,17 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode == GGML_ROPE_TYPE_MROPE || mode == GGML_ROPE_TYPE_VISION) { - GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) { - GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", n_dims, - op->src[0]->ne[0]); + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", n_dims, + // op->src[0]->ne[0]); return true; } if (op->type != GGML_TYPE_F32) { - GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } float freq_scale; @@ -876,15 +876,15 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { memcpy(&freq_scale, op_params + 6, sizeof(float)); memcpy(&ext_factor, op_params + 7, sizeof(float)); if (ext_factor != 0.0f) { - GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); return true; } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { - GGML_LOG_WARN( - "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] " - "%ld\n", - op->src[0]->view_src->ne[1], op->src[0]->ne[2]); + // GGML_LOG_WARN( + // "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] " + // "%ld\n", + // op->src[0]->view_src->ne[1], op->src[0]->ne[2]); return true; } } @@ -921,12 +921,12 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con case GGML_OP_UNARY: { auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end(); if (!supported) { - GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op))); + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op))); return false; } if (has_view_input(op)) { - GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", - ggml_unary_op_name(ggml_get_unary_op(op))); + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_unary_op_name(ggml_get_unary_op(op))); return false; } break; @@ -934,12 +934,12 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con case GGML_OP_GLU: { auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end(); if (!supported) { - GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op))); + // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op))); return false; } if (has_view_input(op)) { - GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", - ggml_glu_op_name(ggml_get_glu_op(op))); + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_glu_op_name(ggml_get_glu_op(op))); return false; } break; @@ -947,7 +947,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con default: { auto supported = supported_ops.find(op->op) != supported_ops.end(); if (!supported) { - GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op)); + // GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op)); return false; } static std::set ops_not_support_view_input{ @@ -955,14 +955,14 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con GGML_OP_RMS_NORM, }; if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_input(op)) { - GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op)); + // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op)); return false; } } } if (supported_types.find(op->type) == supported_types.end()) { - GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); return false; } for (int i = 0; i < GGML_MAX_SRC; i++) { @@ -971,11 +971,11 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con break; } if (supported_types.find(src->type) == supported_types.end()) { - GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); return false; } if (ggml_is_quantized(src->type) && src->ne[2] != 1) { - GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); + // GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); return false; } } diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 44e3368217..22fb7e2ba2 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -66,10 +66,10 @@ OutputVector translate_rope(const NodeContext & context) { } const int mode = op_params[2]; + constexpr int ROPE_TYPE_NORMAL = 0; constexpr int ROPE_TYPE_NEOX = 2; - constexpr int ROPE_TYPE_NORM = 0; - if (mode == ROPE_TYPE_NORM) { + if (mode == ROPE_TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index a370043dd7..e79f582939 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -38,19 +38,31 @@ #pragma GCC diagnostic ignored "-Wdeprecated-declarations" enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) { - if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { - std::string filename = "cgraph_ov.txt"; - GgmlOvDecoder::dump_cgraph(cgraph, filename); - } + try { + if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { + std::string filename = "cgraph_ov.txt"; + GgmlOvDecoder::dump_cgraph(cgraph, filename); + } - // Use device from singleton (initialized during backend init) - const auto & device = ggml_openvino_get_device_name(); - const auto is_static = ggml_openvino_is_npu(); - bool stateful = false; - if (getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !is_static) { - stateful = true; + // Use device from singleton (initialized during backend init) + const auto & device = ggml_openvino_get_device_name(); + const auto is_static = ggml_openvino_is_npu(); + bool stateful = false; + if (getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !is_static) { + stateful = true; + } + + return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful); + } catch (const ov::Exception & e) { + // GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (const std::exception & e) { + // GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (...) { + // GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n"); + return GGML_STATUS_FAILED; } - return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful); } enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device, bool stateful) { @@ -454,9 +466,6 @@ enum ggml_status naive_compute(ggml_cgraph * cgraph, if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) { return GGML_STATUS_SUCCESS; } - if (cgraph->nodes[0]->op == GGML_OP_FLASH_ATTN_EXT) { - return GGML_STATUS_FAILED; - } auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); auto decoder = std::make_shared(cgraph, model_weights);