From 91d2a195b56dd4967846951cf1dbaf646576438b Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 15 Apr 2025 14:34:00 +0800 Subject: [PATCH] change op mappings to list in openvino_supports_op --- ggml/src/ggml-openvino.cpp | 96 +++----------------------------- ggml/src/ggml-openvino/utils.cpp | 21 +++---- ggml/src/ggml-openvino/utils.h | 2 +- 3 files changed, 17 insertions(+), 102 deletions(-) diff --git a/ggml/src/ggml-openvino.cpp b/ggml/src/ggml-openvino.cpp index 762ed786a9..5ea2351e06 100644 --- a/ggml/src/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino.cpp @@ -1036,9 +1036,7 @@ static enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backe // Process nodes in order - bool prompt_process_flag = true; if (cgraph->nodes[0]->ne[1] == 1) { - prompt_process_flag = false; for (int i = 0; i < cgraph->n_nodes; i++) { if (std::find(add_indices.begin(), add_indices.end(), i) != add_indices.end()) { ggml_backend_openvino_add_forward(cgraph->nodes[i]); @@ -1066,13 +1064,13 @@ static enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backe i++; } if (start_index < i) { - openvino_frontend_compute(backend, cgraph, start_index, --i, prompt_process_flag); + openvino_frontend_compute(backend, cgraph, start_index, --i); } } } } else { int end_node = cgraph->n_nodes - 1; - openvino_frontend_compute(backend, cgraph, 0, end_node, prompt_process_flag); + openvino_frontend_compute(backend, cgraph, 0, end_node); } return GGML_STATUS_SUCCESS; @@ -1331,91 +1329,11 @@ static const std::set& openvino_ops = []() -> const std::set> op_mapping = { - {GGML_OP_ACC, {"Add"}}, - {GGML_OP_ADD, {"Add"}}, - {GGML_OP_ADD1, {"Add"}}, - {GGML_OP_ADD_REL_POS, {"Add", "MatMul", "Reshape"}}, - {GGML_OP_ARANGE, {"Range"}}, - {GGML_OP_ARGMAX, {"TopK"}}, - {GGML_OP_ARGSORT, {"TopK"}}, - {GGML_OP_CLAMP, {"Clamp"}}, - {GGML_OP_CONCAT, {"Concat"}}, - {GGML_OP_CONV_TRANSPOSE_1D, {"ConvolutionBackpropData"}}, - {GGML_OP_CONV_TRANSPOSE_2D, {"ConvolutionBackpropData"}}, - {GGML_OP_COS, {"Cos"}}, - {GGML_OP_CROSS_ENTROPY_LOSS, {"Softmax", "Log", "Multiply", "ReduceSum", "Negative"}}, - {GGML_OP_DIAG, {"Eye", "Multiply"}}, - {GGML_OP_DIAG_MASK_INF, {"Eye", "Multiply", "Select", "Broadcast"}}, - {GGML_OP_DIAG_MASK_ZERO, {"Eye", "Multiply", "Select", "Broadcast"}}, - {GGML_OP_DIV, {"Divide"}}, - {GGML_OP_FLASH_ATTN_EXT, {"ScaledDotProductAttention"}}, - {GGML_OP_GET_ROWS, {"Gather"}}, - {GGML_OP_GROUP_NORM, {"GroupNormalization"}}, - {GGML_OP_IM2COL, {"Custom", "Reshape", "Transpose"}}, - {GGML_OP_LEAKY_RELU, {"PReLU"}}, - {GGML_OP_LOG, {"Log"}}, - {GGML_OP_MEAN, {"ReduceMean"}}, - {GGML_OP_MUL, {"Multiply"}}, - {GGML_OP_MUL_MAT, {"MatMul"}}, - {GGML_OP_MUL_MAT_ID, {"MatMul", "Identity"}}, - {GGML_OP_NORM, {"NormalizeL2"}}, - {GGML_OP_OUT_PROD, {"MatMul", "Reshape"}}, - {GGML_OP_PAD, {"Pad"}}, - {GGML_OP_PERMUTE, {"Transpose"}}, - {GGML_OP_POOL_1D, {"AvgPool", "MaxPool"}}, - {GGML_OP_POOL_2D, {"AvgPool", "MaxPool"}}, - {GGML_OP_REPEAT, {"Tile"}}, - {GGML_OP_RESHAPE, {"Reshape"}}, - {GGML_OP_RMS_NORM, {"Multiply", "Divide", "Sqrt"}}, - {GGML_OP_ROPE, {"Sin", "Cos", "Multiply", "Add", "Subtract", "Split", "StridedSlice", "Concat"}}, - {GGML_OP_SCALE, {"Multiply", "Constant"}}, - {GGML_OP_SET, {"Assign"}}, - {GGML_OP_SIN, {"Sin"}}, - {GGML_OP_SOFT_MAX, {"Softmax"}}, - {GGML_OP_SQR, {"Power"}}, - {GGML_OP_SQRT, {"Sqrt"}}, - {GGML_OP_SSM_CONV, {"Custom"}}, - {GGML_OP_SSM_SCAN, {"Custom"}}, - {GGML_OP_SUB, {"Subtract"}}, - {GGML_OP_SUM, {"ReduceSum"}}, - {GGML_OP_SUM_ROWS, {"ReduceSum", "Squeeze", "Unsqueeze"}}, - {GGML_OP_TIMESTEP_EMBEDDING, {"Range", "Power", "Multiply", "Sin", "Cos", "Concat"}}, - {GGML_OP_TRANSPOSE, {"Transpose"}}, - {GGML_OP_UPSCALE, {"Interpolate"}}, - {GGML_OP_VIEW, {"Reshape"}}, - {GGML_OP_CONT, {"Reshape", "StridedSlice"}}, - {GGML_OP_CPY, {"Reshape", "ScatterNDUpdate"}}, - {GGML_OP_WIN_PART, {"StridedSlice", "Concat", "Reshape", "Custom"}}, - {GGML_OP_WIN_UNPART, {"Reshape", "Transpose", "Custom"}}, - }; - - static const std::map> op_mapping_unary = { - {GGML_UNARY_OP_SILU, {"Sigmoid", "Multiply"}}, - }; - - std::vector mapped_ops; - if (op->op == GGML_OP_UNARY) { - auto it = op_mapping_unary.find(ggml_get_unary_op(op)); - if (it == op_mapping_unary.end()) { - return false; - } - mapped_ops = it->second; - } else { - auto it = op_mapping.find(op->op); - if (it == op_mapping.end()) { - return false; - } - mapped_ops = it->second; - } - - for (const std::string& op_name : mapped_ops) { - if (openvino_ops.count(op_name) == 0) { - return false; - } - } - - return true; + if (op->op == GGML_OP_UNARY) { + return supported_unary_ops.find(ggml_get_unary_op(op)) != + supported_unary_ops.end(); + } + return supported_ops.find(op->op) != supported_ops.end(); } static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index f4d9c7705a..c32ad65842 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -1,6 +1,7 @@ #include "utils.h" #include "ggml-backend-impl.h" #include "ggml-impl.h" +#include "ggml.h" #include #include #include @@ -13,7 +14,7 @@ std::shared_ptr get_ggml_decoder(struct ggml_cgraph * cgraph, con return std::make_shared(nullptr, cgraph, start_index, end_index); } -std::vector> get_ggml_graph_input_tensors(std::shared_ptr ggml_decoder, bool flag) { +std::vector> get_ggml_graph_input_tensors(std::shared_ptr ggml_decoder) { std::vector> input_tensors; auto input_names = ggml_decoder->get_input_names(); size_t op_iter = 0; @@ -77,10 +78,13 @@ static ov::frontend::FrontEnd::Ptr get_ggml_frontend() { return front_end; } -enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, const int32_t start_index, const int32_t end_index, bool flag) { +enum ggml_status openvino_frontend_compute(ggml_backend_t backend, + struct ggml_cgraph *cgraph, + const int32_t start_index, + const int32_t end_index) { static ov::Core core; + // auto devices = core.get_available_devices(); - // Get GGML Frontend static auto front_end = get_ggml_frontend(); if (!front_end) { GGML_LOG_ERROR("GGML FrontEnd is not initialized \n"); @@ -90,6 +94,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c GGML_LOG_INFO("GGML FrontEnd is initialized \n"); #endif } + auto ggml_decoder = get_ggml_decoder(cgraph, start_index, end_index); std::shared_ptr graph_decoder = ggml_decoder; // Load GraphIterator -> InputModel @@ -123,26 +128,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c } ov::CompiledModel compiled_model = core.compile_model(model); - - // Create infer request ov::InferRequest infer_request = compiled_model.create_infer_request(); - // Get input tensor auto input_names = ggml_decoder->get_input_names(); - auto input_tensors = get_ggml_graph_input_tensors(ggml_decoder, flag); - + auto input_tensors = get_ggml_graph_input_tensors(ggml_decoder); auto ov_params = model->get_parameters(); for (size_t i = 0; i < ov_params.size(); i++) { auto param_name = ov_params[i]->get_friendly_name(); infer_request.set_input_tensor(i, get_ggml_graph_input_tensor(ggml_decoder, param_name)); } - // for (size_t i = 0; i < input_names.size(); i++) { - // infer_request.set_input_tensor(i, input_tensors.at(i).second); - // } infer_request.infer(); - // Set dst data for outputs auto output_names = ggml_decoder->get_output_names(); auto output_tensors = get_ggml_graph_output_dst(ggml_decoder); for (size_t i = 0; i < output_names.size(); i++) { diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 7806c418cb..0f5617ab4b 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -1,4 +1,4 @@ #include "ggml-decoder.h" #include "ggml-backend-impl.h" -enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, const int32_t start_index=0, const int32_t end_index=0, bool flag = true); +enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, const int32_t start_index=0, const int32_t end_index=0);