From 810eb480f529148ee6e20437755dbb3273589f60 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Tue, 2 Sep 2025 14:53:09 +0800 Subject: [PATCH] Simpilfy translation of get_rows --- .../ggml-openvino/openvino/op/get_rows.cpp | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp index 0de77da59f..5e4c7d901a 100644 --- a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -3,10 +3,7 @@ #include #include #include -#include -#include #include -#include #include "../node_context.hpp" #include "../op_table.hpp" @@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) { indices = process_view_input(context, 1); } - Output axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); - if (indices.get_partial_shape()[1].get_length() == 1) { - indices = - std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); - if (data.get_partial_shape().rank() == 2) { - axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); - } - res = std::make_shared(data, indices, axis); - if (data.get_partial_shape().rank() == 2) { - res = - std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); - } - } else { - indices = - std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + // data[b,x,y] ind[1,b,x'] test-backend-ops case + // data[x,y] ind[1,1,x'] normal case + indices = std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + if (data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); res = std::make_shared(data, indices, axis, 1); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared(data, indices, axis); } if (res.get_element_type() != context.get_output_type(0)) {