Replace Concat with Broadcast in MulMat for GQA

This commit is contained in:
Yu, Zijun 2025-07-04 14:38:15 +08:00 committed by Mustafa Cavus
parent ebc4fc9f95
commit bf5414c95e
2 changed files with 16 additions and 7 deletions

View File

@@ -118,6 +118,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
}
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;
}
}
@@ -262,6 +263,7 @@ void GgmlOvDecoder::add_extra_inputs() {
std::string name = "past_token_len";
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
param_node->set_friendly_name(name);
param_node->output(0).get_tensor().set_names({name});
m_model_extra_inputs[name] = param_node;
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
@@ -280,6 +282,7 @@ void GgmlOvDecoder::add_extra_inputs() {
std::string name = "attention_size";
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
param_node->set_friendly_name(name);
param_node->output(0).get_tensor().set_names({name});
m_model_extra_inputs[name] = param_node;
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});

View File

@@ -3,6 +3,7 @@
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
@@ -10,6 +11,7 @@
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <vector>
#include "../node_context.hpp"
@@ -45,16 +47,20 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads});
auto num_heads_kv_node =
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads_kv});
auto factor_node =
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_num_heads_factor});
auto B_shape_last_two = get_dimensions(B.get_node_shared_ptr(), {1, 2});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
std::shared_ptr<ov::Node> new_B_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_kv_node, one, B_shape_last_two}, 0);
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto B_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(B, unsqueeze_axes);
B = std::make_shared<ov::op::v0::Concat>(ov::OutputVector(kv_num_heads_factor, B), 1);
new_B_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_node, B_shape_last_two}, 0);
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
auto broadcast_shape = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{num_heads_kv_node, factor_node, B_shape_last_two}, 0);
auto B_broadcasted = std::make_shared<ov::op::v3::Broadcast>(B_unsqueezed, broadcast_shape);
auto new_B_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_node, B_shape_last_two}, 0);
B = std::make_shared<ov::op::v1::Reshape>(B_broadcasted, new_B_shape, false);
}
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);