mulmat type conversion update
This commit is contained in:
parent
e2fdc1b988
commit
93b2d09a2d
|
|
@ -30,10 +30,13 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
ov::Output<Node> res;
|
ov::Output<Node> res;
|
||||||
ov::Output<ov::Node> B = context.get_input(0);
|
ov::Output<ov::Node> B = context.get_input(0);
|
||||||
ov::Output<ov::Node> A = context.get_input(1);
|
ov::Output<ov::Node> A = context.get_input(1);
|
||||||
|
|
||||||
|
bool convert_out_type = false;
|
||||||
if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) {
|
if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) {
|
||||||
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
|
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
|
||||||
} else if (context.get_input_type(0) != context.get_input_type(1)) {
|
} else if (context.get_input_type(0) != context.get_input_type(1)) {
|
||||||
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
||||||
|
convert_out_type = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto B_shape = context.get_input_shape(0).to_shape();
|
auto B_shape = context.get_input_shape(0).to_shape();
|
||||||
|
|
@ -68,7 +71,12 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
A = Z;
|
A = Z;
|
||||||
}
|
}
|
||||||
|
|
||||||
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
if (convert_out_type) {
|
||||||
|
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||||
|
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||||
|
} else {
|
||||||
|
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||||
|
}
|
||||||
|
|
||||||
return rename_outputs_with_suffix({res}, context.get_name());
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue