Simpilfy translation of get_rows
This commit is contained in:
parent
c5231a2448
commit
810eb480f5
|
|
@ -3,10 +3,7 @@
|
|||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
#include "../op_table.hpp"
|
||||
|
|
@ -31,22 +28,15 @@ OutputVector translate_get_rows(const NodeContext& context) {
|
|||
indices = process_view_input(context, 1);
|
||||
}
|
||||
|
||||
Output<Node> axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
|
||||
if (indices.get_partial_shape()[1].get_length() == 1) {
|
||||
indices =
|
||||
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
if (data.get_partial_shape().rank() == 2) {
|
||||
axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
|
||||
}
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
|
||||
if (data.get_partial_shape().rank() == 2) {
|
||||
res =
|
||||
std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
}
|
||||
} else {
|
||||
indices =
|
||||
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
// data[b,x,y] ind[1,b,x'] test-backend-ops case
|
||||
// data[x,y] ind[1,1,x'] normal case
|
||||
indices = std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
if (data.get_partial_shape().rank() == 3) {
|
||||
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
|
||||
} else {
|
||||
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
|
||||
}
|
||||
|
||||
if (res.get_element_type() != context.get_output_type(0)) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue