mulmat input conversion fix
This commit is contained in:
parent
01cdf4a9cc
commit
e2fdc1b988
|
|
@ -12,6 +12,7 @@
|
|||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/op/util/op_types.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
|
|
@ -29,8 +30,10 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
ov::Output<Node> res;
|
||||
ov::Output<ov::Node> B = context.get_input(0);
|
||||
ov::Output<ov::Node> A = context.get_input(1);
|
||||
if (context.get_input_type(0) != context.get_input_type(1)) {
|
||||
if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) {
|
||||
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
|
||||
} else if (context.get_input_type(0) != context.get_input_type(1)) {
|
||||
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
||||
}
|
||||
|
||||
auto B_shape = context.get_input_shape(0).to_shape();
|
||||
|
|
|
|||
Loading…
Reference in New Issue