Simpilfy translation of get_rows

This commit is contained in:
Yu, Zijun 2025-09-02 14:53:09 +08:00 committed by Mustafa Cavus
parent c5231a2448
commit 810eb480f5
1 changed files with 8 additions and 18 deletions

View File

@ -3,10 +3,7 @@
#include <openvino/op/constant.hpp> #include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp> #include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp> #include <openvino/op/gather.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp> #include <openvino/op/squeeze.hpp>
#include <openvino/op/unsqueeze.hpp>
#include "../node_context.hpp" #include "../node_context.hpp"
#include "../op_table.hpp" #include "../op_table.hpp"
@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) {
indices = process_view_input(context, 1); indices = process_view_input(context, 1);
} }
Output<Node> axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); // data[b,x,y] ind[1,b,x'] test-backend-ops case
if (indices.get_partial_shape()[1].get_length() == 1) { // data[x,y] ind[1,1,x'] normal case
indices = indices = std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); if (data.get_partial_shape().rank() == 3) {
if (data.get_partial_shape().rank() == 2) { auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
}
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
if (data.get_partial_shape().rank() == 2) {
res =
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
} else {
indices =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1); res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
} else {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
} }
if (res.get_element_type() != context.get_output_type(0)) { if (res.get_element_type() != context.get_output_type(0)) {