From 8ce5cc597a5f18e4adfa090e6394cbfefbb458db Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 22 May 2025 10:32:18 +0800 Subject: [PATCH] Add cgraph tensor output name to OV op name --- ggml/src/ggml-openvino/openvino/op/add.cpp | 7 +++---- ggml/src/ggml-openvino/openvino/op/cont.cpp | 13 ++++++------ ggml/src/ggml-openvino/openvino/op/cpy.cpp | 10 ++++----- .../ggml-openvino/openvino/op/get_rows.cpp | 2 +- ggml/src/ggml-openvino/openvino/op/mul.cpp | 2 +- ggml/src/ggml-openvino/openvino/op/mulmat.cpp | 11 +++++----- .../src/ggml-openvino/openvino/op/permute.cpp | 2 +- .../src/ggml-openvino/openvino/op/reshape.cpp | 4 ++-- .../ggml-openvino/openvino/op/rms_norm.cpp | 2 +- ggml/src/ggml-openvino/openvino/op/rope.cpp | 11 +++++----- ggml/src/ggml-openvino/openvino/op/scale.cpp | 2 +- .../ggml-openvino/openvino/op/soft_max.cpp | 21 ++++++++++--------- .../ggml-openvino/openvino/op/transpose.cpp | 2 +- .../ggml-openvino/openvino/op/unary_silu.cpp | 2 +- .../openvino/translate_session.cpp | 14 +++++++++++-- ggml/src/ggml-openvino/openvino/utils.cpp | 11 ++++++++++ ggml/src/ggml-openvino/openvino/utils.hpp | 2 ++ 17 files changed, 71 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-openvino/openvino/op/add.cpp b/ggml/src/ggml-openvino/openvino/op/add.cpp index 18bc463fb9..5a75ff2148 100644 --- a/ggml/src/ggml-openvino/openvino/op/add.cpp +++ b/ggml/src/ggml-openvino/openvino/op/add.cpp @@ -11,10 +11,9 @@ namespace op { OutputVector translate_add(const NodeContext& context) { num_inputs_check(context, 2, 2); - auto lhs = context.get_input(0); - auto rhs = context.get_input(1); - auto add = std::make_shared(lhs, rhs); - return {add}; + auto res = std::make_shared(context.get_input(0), context.get_input(1)); + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp index a052bf06ca..7cdfba051e 100644 --- a/ggml/src/ggml-openvino/openvino/op/cont.cpp +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -22,16 +22,15 @@ OutputVector translate_cont(const NodeContext& context) { auto src_shape = context.get_input_shape(0).to_shape(); auto dst_shape = context.get_output_shape(0).to_shape(); + ov::Output res; if (op_case == 1) { // The input comes from a PERMUTE dst_shape[1] = -1; - auto result = std::make_shared( + res = std::make_shared( context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false); - - return {result}; } else { // The input comes from a VIEW // Currently all cases are slicing at lowest dim @@ -43,13 +42,13 @@ OutputVector translate_cont(const NodeContext& context) { std::vector end = {(int64_t)src_shape[0], INT_MAX, split_addr + (int64_t)src_shape[2]}; std::vector strides = {1, 1, 1}; - auto begin_const = ov::op::v0::Constant::create(ov::element::i64, {begin.size()}, begin); + auto begin_const = ov::op::v0::Constant::create(element::i64, {begin.size()}, begin); auto end_const = ov::op::v0::Constant::create(ov::element::i64, {end.size()}, end); auto strides_const = ov::op::v0::Constant::create(ov::element::i64, {strides.size()}, strides); - auto slice = std::make_shared(context.get_input(0), begin_const, end_const, strides_const); - - return {slice}; + res = std::make_shared(context.get_input(0), begin_const, end_const, strides_const); } + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp index 0c4a3d1558..7cdeddce38 100644 --- a/ggml/src/ggml-openvino/openvino/op/cpy.cpp +++ b/ggml/src/ggml-openvino/openvino/op/cpy.cpp @@ -33,6 +33,7 @@ OutputVector translate_cpy(const NodeContext& context) { auto src0 = context.get_input(0); auto src1 = context.get_input(1); auto past_token_len = context.get_input("past_token_len"); + ov::Output res; auto src0_shape = context.get_input_shape(0).to_shape(); auto output_shape = context.get_output_shape(0).to_shape(); @@ -63,8 +64,7 @@ OutputVector translate_cpy(const NodeContext& context) { indices, ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{1})); - auto res = std::make_shared(reshaped_src1, indices, src0); - return {res}; + res = std::make_shared(reshaped_src1, indices, src0); } else { // Write V to cache_v int64_t total_head_size = src0_shape[1]; @@ -99,10 +99,10 @@ OutputVector translate_cpy(const NodeContext& context) { ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{1, total_head_size, -1}), false); - auto res = std::make_shared(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2); - - return {res}; + res = std::make_shared(ov::OutputVector{src1_left, reshaped_src0, src1_right}, 2); } + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp index 64fc57bd88..ca36548d9f 100644 --- a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -30,7 +30,7 @@ OutputVector translate_get_rows(const NodeContext& context) { res = std::make_shared(res, context.get_output_type(0)); } - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/mul.cpp b/ggml/src/ggml-openvino/openvino/op/mul.cpp index 14473f4e27..40caf4331e 100644 --- a/ggml/src/ggml-openvino/openvino/op/mul.cpp +++ b/ggml/src/ggml-openvino/openvino/op/mul.cpp @@ -12,7 +12,7 @@ OutputVector translate_mul(const NodeContext& context) { num_inputs_check(context, 2, 2); auto res = std::make_shared(context.get_input(0), context.get_input(1)); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp index 5673551f70..06e7d9ece0 100644 --- a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -25,12 +25,13 @@ OutputVector translate_mulmat(const NodeContext& context) { int op_case = context.get_op_case(); FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported MULMAT case"); + ov::Output res; + if (op_case == 1) { auto src0 = context.get_input(0); auto src1 = std::make_shared(context.get_input(1), context.get_input_type(0)); auto result_lp = std::make_shared(src1, src0, false, true); - auto result = std::make_shared(result_lp, context.get_output_type(0)); - return {result}; + res = std::make_shared(result_lp, context.get_output_type(0)); } else { /* Two cases here: @@ -118,10 +119,10 @@ OutputVector translate_mulmat(const NodeContext& context) { } auto result_lp = std::make_shared(A, B, false, true); - auto result = std::make_shared(result_lp, context.get_output_type(0)); - - return {result}; + res = std::make_shared(result_lp, context.get_output_type(0)); } + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp index 478c9430f0..649cf8f3e1 100644 --- a/ggml/src/ggml-openvino/openvino/op/permute.cpp +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -15,7 +15,7 @@ OutputVector translate_permute(const NodeContext& context) { auto perm = argsort_descend(context.get_output_stride(0)); auto res = std::make_shared(context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {3}, perm)); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp index f6586d674c..49551eb815 100644 --- a/ggml/src/ggml-openvino/openvino/op/reshape.cpp +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -37,8 +37,8 @@ OutputVector translate_reshape(const NodeContext& context) { {3}, std::vector{(int64_t)output_shape[0], -1, (int64_t)output_shape[2]}); } - Output res = std::make_shared(context.get_input(0), new_shape_node, false); - return {res}; + auto res = std::make_shared(context.get_input(0), new_shape_node, false); + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp index a91fffb72d..7b8b582dac 100644 --- a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp @@ -39,7 +39,7 @@ OutputVector translate_rms_norm(const NodeContext& context) { auto res = std::make_shared(input_node, scale); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index aad156082e..94810e549d 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -52,6 +52,8 @@ void ggml_rope_yarn_corr_dims(int n_dims, OutputVector translate_rope(const NodeContext& context) { num_inputs_check(context, 2, 3); + ov::Output res; + auto data_node = context.get_input(0); auto pos_node = context.get_input(1); pos_node = std::make_shared(pos_node, ov::element::f32); @@ -141,9 +143,7 @@ OutputVector translate_rope(const NodeContext& context) { ov::element::i64, Shape{3}, std::vector{-1, input_shape[1].get_length(), input_shape[2].get_length()}); - auto reshaped = std::make_shared(stack, shape_const, false); - - return {reshaped}; + res = std::make_shared(stack, shape_const, false); } else { auto slice_node = std::make_shared(data_node, @@ -160,9 +160,10 @@ OutputVector translate_rope(const NodeContext& context) { std::make_shared(slice_data_node_0, sin_theta_node), std::make_shared(slice_data_node_1, cos_theta_node)); - auto res_node = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, 2); - return {res_node}; + res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, 2); } + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp index b393dd8aa2..8f0999432c 100644 --- a/ggml/src/ggml-openvino/openvino/op/scale.cpp +++ b/ggml/src/ggml-openvino/openvino/op/scale.cpp @@ -19,7 +19,7 @@ OutputVector translate_scale(const NodeContext& context) { auto res = std::make_shared(context.get_input(0), scale_node); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/soft_max.cpp b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp index 549c35a9b6..bb6b002395 100644 --- a/ggml/src/ggml-openvino/openvino/op/soft_max.cpp +++ b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp @@ -24,6 +24,7 @@ OutputVector translate_soft_max(const NodeContext& context) { num_inputs_check(context, 1, 2); auto input_node = context.get_input(0); + ov::Output res; float scale = 1.0f; float max_bias = 0.0f; @@ -56,13 +57,13 @@ OutputVector translate_soft_max(const NodeContext& context) { } // Stride slice mask node - Output mask_begin_node = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0}); + Output slice_start = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0}); auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); - auto input_last_two_dim = get_dimensions(input_node.get_node_shared_ptr(), {1, 2}); - auto mask_slice_shape = std::make_shared(ov::NodeVector{one, input_last_two_dim}, 0); - Output mask_stride_node = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 1}); - auto mask_node_sliced = - std::make_shared(mask_node, mask_begin_node, mask_slice_shape, mask_stride_node); + auto token_len = get_dimensions(input_node.get_node_shared_ptr(), {1}); + auto total_token_len = get_dimensions(mask_node.get_node_shared_ptr(), {2}); + auto slice_end = std::make_shared(ov::NodeVector{one, token_len, total_token_len}, 0); + Output slice_stride = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 1}); + auto mask_node_sliced = std::make_shared(mask_node, slice_start, slice_end, slice_stride); // slope * mask auto slope_node = @@ -73,13 +74,13 @@ OutputVector translate_soft_max(const NodeContext& context) { auto input_slope_mask_node = std::make_shared(input_node, slope_mask_node); // Calculate softmax - auto res = std::make_shared(input_slope_mask_node, 2); - return {res}; + res = std::make_shared(input_slope_mask_node, 2); } else { // Directly softmax - auto res = std::make_shared(input_node, 0); - return {res}; + res = std::make_shared(input_node, 0); } + + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp index 7d33ca9d61..99178a1944 100644 --- a/ggml/src/ggml-openvino/openvino/op/transpose.cpp +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -14,7 +14,7 @@ OutputVector translate_transpose(const NodeContext& context) { auto perm = argsort_descend(context.get_output_stride(0)); auto res = std::make_shared(context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {3}, perm)); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp index 1c396e6aaf..6c73653ca4 100644 --- a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp @@ -17,7 +17,7 @@ OutputVector translate_unary_silu(const NodeContext& context) { auto sigmoid = std::make_shared(input); auto res = std::make_shared(input, sigmoid); - return {res}; + return rename_outputs_with_suffix({res}, context.get_name()); } } // namespace op diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 910a0d8336..8eda23c1c5 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -1,5 +1,8 @@ #include "translate_session.hpp" +#include +#include + #include "input_model.hpp" namespace ov { @@ -91,11 +94,18 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo used_params.push_back(param); } } - if (auto diff = params.size() - used_params.size()) { - std::cout << diff << " parameters are not used in the model." << std::endl; + if (getenv("GGML_OPENVINO_PROFILING")) { + if (auto diff = params.size() - used_params.size()) { + std::cout << diff << " parameters are not used in the model." << std::endl; + } } resulting_model = std::make_shared(results, used_params); + ov::pass::Manager manager; + manager.set_per_pass_validation(true); + manager.register_pass(); + manager.run_passes(resulting_model); + return resulting_model; } diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index ff16e9d4ae..69e26f05ca 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -47,6 +47,17 @@ std::shared_ptr get_dimensions(const std::shared_ptr& node, return get_dimensions(std::make_shared(node), dims); } +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix) { + for (const auto& output : outputs) { + auto node = output.get_node_shared_ptr(); + std::string name = node->get_friendly_name(); + name += "_"; + name += suffix; + node->set_friendly_name(name); + } + return outputs; +} + } // namespace ggml } // namespace frontend } // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.hpp b/ggml/src/ggml-openvino/openvino/utils.hpp index 6e106fa932..e0fe250789 100644 --- a/ggml/src/ggml-openvino/openvino/utils.hpp +++ b/ggml/src/ggml-openvino/openvino/utils.hpp @@ -55,6 +55,8 @@ std::vector permute(const std::vector& x, const std::vector& perm) { std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix); + namespace op { template OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {