Replace Concat with Broadcast in MulMat for GQA
This commit is contained in:
parent commit: ebc4fc9f95
this commit: bf5414c95e
@ -118,6 +118,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
|
|||
}
|
||||
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
|
||||
param_node->set_friendly_name(src_name);
|
||||
param_node->output(0).get_tensor().set_names({src_name});
|
||||
m_model_inputs[src_name] = param_node;
|
||||
}
|
||||
}
|
||||
|
|
@ -262,6 +263,7 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
std::string name = "past_token_len";
|
||||
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
|
||||
param_node->set_friendly_name(name);
|
||||
param_node->output(0).get_tensor().set_names({name});
|
||||
m_model_extra_inputs[name] = param_node;
|
||||
|
||||
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
|
||||
|
|
@ -280,6 +282,7 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
std::string name = "attention_size";
|
||||
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
|
||||
param_node->set_friendly_name(name);
|
||||
param_node->output(0).get_tensor().set_names({name});
|
||||
m_model_extra_inputs[name] = param_node;
|
||||
|
||||
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <memory>
|
||||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
|
|
@ -10,6 +11,7 @@
|
|||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
|
|
@ -45,16 +47,20 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads});
|
||||
auto num_heads_kv_node =
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads_kv});
|
||||
auto factor_node =
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_num_heads_factor});
|
||||
auto B_shape_last_two = get_dimensions(B.get_node_shared_ptr(), {1, 2});
|
||||
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
std::shared_ptr<ov::Node> new_B_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_kv_node, one, B_shape_last_two}, 0);
|
||||
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
auto B_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(B, unsqueeze_axes);
|
||||
|
||||
B = std::make_shared<ov::op::v0::Concat>(ov::OutputVector(kv_num_heads_factor, B), 1);
|
||||
new_B_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_node, B_shape_last_two}, 0);
|
||||
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
|
||||
auto broadcast_shape = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{num_heads_kv_node, factor_node, B_shape_last_two}, 0);
|
||||
auto B_broadcasted = std::make_shared<ov::op::v3::Broadcast>(B_unsqueezed, broadcast_shape);
|
||||
|
||||
auto new_B_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{num_heads_node, B_shape_last_two}, 0);
|
||||
B = std::make_shared<ov::op::v1::Reshape>(B_broadcasted, new_B_shape, false);
|
||||
}
|
||||
|
||||
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||
|
|
|
|||
Loading…
Reference in New Issue