Simpilfy translation of get_rows

This commit is contained in:
Yu, Zijun 2025-09-02 14:53:09 +08:00 committed by Mustafa Cavus
parent c5231a2448
commit 810eb480f5
1 changed files with 8 additions and 18 deletions

View File

@ -3,10 +3,7 @@
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/unsqueeze.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) {
indices = process_view_input(context, 1);
}
Output<Node> axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
if (indices.get_partial_shape()[1].get_length() == 1) {
indices =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
if (data.get_partial_shape().rank() == 2) {
axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
}
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
if (data.get_partial_shape().rank() == 2) {
res =
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
} else {
indices =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
// data[b,x,y] ind[1,b,x'] test-backend-ops case
// data[x,y] ind[1,1,x'] normal case
indices = std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
if (data.get_partial_shape().rank() == 3) {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
} else {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
}
if (res.get_element_type() != context.get_output_type(0)) {