mulmat input conversion fix

This commit is contained in:
Cavus Mustafa 2025-07-29 17:55:15 -07:00 committed by Mustafa Cavus
parent 01cdf4a9cc
commit e2fdc1b988
1 changed files with 4 additions and 1 deletions

View File

@ -12,6 +12,7 @@
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/op/util/op_types.hpp>
#include <vector>
#include "../node_context.hpp"
@ -29,8 +30,10 @@ OutputVector translate_mulmat(const NodeContext& context) {
ov::Output<Node> res;
ov::Output<ov::Node> B = context.get_input(0);
ov::Output<ov::Node> A = context.get_input(1);
if (context.get_input_type(0) != context.get_input_type(1)) {
if (ov::op::util::is_constant(B.get_node()) && context.get_input_type(0) != context.get_input_type(1)) {
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
} else if (context.get_input_type(0) != context.get_input_type(1)) {
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
}
auto B_shape = context.get_input_shape(0).to_shape();