Fuse to SDPA

This commit is contained in:
Yu, Zijun 2025-07-03 13:22:39 +08:00 committed by Mustafa Cavus
parent 73ee84fffe
commit ebc4fc9f95
12 changed files with 189 additions and 93 deletions

View File

@ -26,27 +26,36 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-backend.h" #include "ggml-backend.h"
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* m_cgraph, bool is_static, GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size) :
GgmlOvDecoder::GgmlOvDecoder(node, cgraph, is_static, is_first_token) {
m_context_size = context_size;
m_num_heads = num_heads;
m_num_heads_kv = num_heads_kv;
m_head_size = head_size;
}
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static,
bool is_first_token) : bool is_first_token) :
m_cgraph(m_cgraph), m_cgraph(cgraph),
m_node(node), m_node(node),
m_op_name(m_node ? std::string(m_node->name) : "NONE_OP"), m_op_name(m_node ? std::string(m_node->name) : "NONE_OP"),
m_is_static(is_static), m_is_static(is_static),
m_is_first_token(is_first_token) { m_is_first_token(is_first_token) {
// TODO avoid static
static std::map<std::string, std::shared_ptr<ov::Node>> model_weights; static std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
if (m_node) { if (m_node) {
set_input_output(m_node); set_input_output(m_node);
} else { } else {
static bool printed = false; static bool printed = false;
if (!printed && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) { if (!printed && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) {
print_tensor_address_map(m_cgraph); print_tensor_address_map(cgraph);
printed = true; printed = true;
} }
if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
std::string filename = "cgraph.txt"; std::string filename = "cgraph.txt";
dump_cgraph(m_cgraph, filename); dump_cgraph(cgraph, filename);
} }
set_llm_params(); set_llm_params();
@ -57,8 +66,8 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* m_cgr
weight_created = true; weight_created = true;
} }
for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) { for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
auto* cur_node = m_cgraph->nodes[node_n]; auto* cur_node = cgraph->nodes[node_n];
m_nodes.push_back(cur_node); m_nodes.push_back(cur_node);
set_input_output(cur_node); set_input_output(cur_node);
} }
@ -195,7 +204,7 @@ void GgmlOvDecoder::set_llm_params() {
auto* node = m_cgraph->nodes[i]; auto* node = m_cgraph->nodes[i];
if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") { if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
auto* cache_k = node->src[0]; auto* cache_k = node->src[0];
m_max_token_len = cache_k->ne[1]; m_context_size = cache_k->ne[1];
} else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Qcur-0") { } else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Qcur-0") {
m_head_size = node->ne[0]; m_head_size = node->ne[0];
m_num_heads = node->ne[1]; m_num_heads = node->ne[1];
@ -210,30 +219,30 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") { if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
if (m_is_static) { if (m_is_static) {
if (m_is_first_token) { if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len }; input_shape = ov::PartialShape{1, 1, m_context_size};
} else { } else {
input_shape = ov::PartialShape{ 1, 1, 1 }; input_shape = ov::PartialShape{1, 1, 1};
} }
} else { } else {
input_shape = ov::PartialShape{ 1, 1, ov::Dimension(1, m_max_token_len) }; input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
} }
} else if (std::string(src->name) == "KQ_mask") { } else if (std::string(src->name) == "KQ_mask") {
if (m_is_static) { if (m_is_static) {
if (m_is_first_token) { if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, m_max_token_len, m_max_token_len }; input_shape = ov::PartialShape{1, m_context_size, m_context_size};
} else { } else {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len }; input_shape = ov::PartialShape{1, 1, m_context_size};
} }
} else { } else {
auto max_mask_size = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD); auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
input_shape = ov::PartialShape{ 1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size) }; input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
} }
} else if (std::string(src->name).find("cache_k") == 0) { } else if (std::string(src->name).find("cache_k") == 0) {
input_shape = ov::PartialShape{ m_max_token_len, m_num_heads_kv, m_head_size }; input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (std::string(src->name).find("cache_v") == 0) { } else if (std::string(src->name).find("cache_v") == 0) {
input_shape = ov::PartialShape{ m_num_heads_kv, m_head_size, m_max_token_len }; input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
} else { } else {
input_shape = ov::PartialShape{ get_shape(src) }; input_shape = ov::PartialShape{get_shape(src)};
} }
return input_shape; return input_shape;
} }
@ -557,7 +566,8 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const { void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto& node : m_nodes) { for (const auto& node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token); auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
m_num_heads, m_num_heads_kv, m_head_size);
node_visitor(decoder); node_visitor(decoder);
} }
} }

View File

@ -11,9 +11,9 @@
class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
public: public:
using ov::frontend::ggml::GgmlDecoder::GgmlDecoder;
GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token); GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token);
GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size);
virtual ov::Any get_attribute(const std::string& name) const override { virtual ov::Any get_attribute(const std::string& name) const override {
return nullptr; return nullptr;
@ -90,7 +90,7 @@ public:
return m_model_output_names; return m_model_output_names;
} }
virtual int get_max_token_len() const override { return m_max_token_len; } virtual int get_context_size() const override { return m_context_size; }
virtual int get_num_heads() const override { return m_num_heads; } virtual int get_num_heads() const override { return m_num_heads; }
@ -114,7 +114,7 @@ private:
static std::vector<size_t> get_stride(const ggml_tensor* tensor); static std::vector<size_t> get_stride(const ggml_tensor* tensor);
static ov::element::Type get_ov_type(const ggml_tensor* tensor); static ov::element::Type get_ov_type(const ggml_tensor* tensor);
// set max_token_len, num_heads, etc // set context_size, num_heads, etc
void set_llm_params(); void set_llm_params();
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor); static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
@ -136,7 +136,7 @@ private:
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values; std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights; std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names; std::vector<std::string> m_model_output_names;
int m_max_token_len; int m_context_size;
int m_num_heads; int m_num_heads;
int m_num_heads_kv; int m_num_heads_kv;
int m_head_size; int m_head_size;

View File

@ -65,7 +65,7 @@ public:
virtual bool is_static() const = 0; virtual bool is_static() const = 0;
virtual bool is_first_token() const = 0; virtual bool is_first_token() const = 0;
virtual int get_max_token_len() const = 0; virtual int get_context_size() const = 0;
}; };
} // namespace ggml } // namespace ggml

View File

@ -91,11 +91,16 @@ public:
bool is_first_token() const { bool is_first_token() const {
return m_decoder->is_first_token(); return m_decoder->is_first_token();
} }
int get_max_token_len() const {
return m_decoder->get_max_token_len();
}
private: int get_num_heads() const { return m_decoder->get_num_heads(); }
int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
int get_head_size() const { return m_decoder->get_head_size(); }
int get_context_size() const { return m_decoder->get_context_size(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder; std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map; std::shared_ptr<TensorMap>& m_tensor_map;
TranslateSession* m_translate_session; TranslateSession* m_translate_session;

View File

@ -38,9 +38,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
ov::Output<ov::Node> B = context.get_input(0); ov::Output<ov::Node> B = context.get_input(0);
ov::Output<ov::Node> A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0)); ov::Output<ov::Node> A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
auto src0_shape = context.get_input_shape(0).to_shape(); int64_t num_heads = context.get_num_heads();
int64_t num_heads = context.get_input_shape(1).to_shape()[0]; int64_t num_heads_kv = context.get_num_heads_kv();
int64_t num_heads_kv = src0_shape[0];
int64_t kv_num_heads_factor = num_heads / num_heads_kv; int64_t kv_num_heads_factor = num_heads / num_heads_kv;
if (kv_num_heads_factor > 1) { if (kv_num_heads_factor > 1) {
auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads}); auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads});

View File

@ -27,7 +27,7 @@ OutputVector translate_permute(const NodeContext& context) {
if (op_case == 1) { if (op_case == 1) {
auto perm = argsort_descend(context.get_output_stride(0)); auto perm = argsort_descend(context.get_output_stride(0));
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0), res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, { 3 }, perm)); ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
} else { } else {
auto src = context.get_input(0); auto src = context.get_input(0);
auto attention_size = context.get_input("attention_size"); auto attention_size = context.get_input("attention_size");
@ -51,19 +51,16 @@ OutputVector translate_permute(const NodeContext& context) {
false); false);
} }
auto slice_start = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 0)); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 1)); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
std::shared_ptr<ov::Node> slice_end; auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
std::shared_ptr<ov::Node> slice_axis;
if (op_case == 2) { if (op_case == 2) {
slice_end = std::make_shared<ov::op::v0::Concat>( slice_axis = zero;
ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, {src_shape[1], src_shape[2]})},
0);
} else { } else {
slice_end = std::make_shared<ov::op::v0::Concat>( slice_axis = two;
ov::OutputVector{ov::op::v0::Constant::create(ov::element::i64, {2}, {src_shape[1], src_shape[0]}), attention_size},
0);
} }
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, slice_start, slice_end, slice_step); auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, slice_axis);
if (op_case == 2) { if (op_case == 2) {
res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
@ -71,7 +68,7 @@ OutputVector translate_permute(const NodeContext& context) {
res = src_slice; res = src_slice;
} }
} }
return rename_outputs_with_suffix({ res }, context.get_name()); return rename_outputs_with_suffix({res}, context.get_name());
} }
} // namespace op } // namespace op

View File

@ -1,3 +1,5 @@
#include <climits>
#include <cstdint>
#include <memory> #include <memory>
#include <openvino/core/node.hpp> #include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp> #include <openvino/core/node_output.hpp>
@ -5,6 +7,7 @@
#include <openvino/op/concat.hpp> #include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp> #include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp> #include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp> #include <openvino/op/multiply.hpp>
#include <openvino/op/slice.hpp> #include <openvino/op/slice.hpp>
#include <openvino/op/softmax.hpp> #include <openvino/op/softmax.hpp>
@ -22,62 +25,61 @@ namespace op {
OutputVector translate_soft_max(const NodeContext& context) { OutputVector translate_soft_max(const NodeContext& context) {
num_inputs_check(context, 1, 2); num_inputs_check(context, 1, 2);
auto input_node = context.get_input(0); auto input_node = context.get_input(0).get_node_shared_ptr();
ov::Output<Node> res; ov::Output<Node> res;
float scale = 1.0f; float scale = 1.0f;
float max_bias = 0.0f; float max_bias = 0.0f;
auto * op_params = context.get_output_op_params(0); auto* op_params = context.get_output_op_params(0);
memcpy(&scale, (float*)op_params + 0, sizeof(float)); memcpy(&scale, (float*) op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)op_params + 1, sizeof(float)); memcpy(&max_bias, (float*) op_params + 1, sizeof(float));
const uint32_t h = context.get_head_size();
// const uint32_t n_head = context.get_input_shape(0)[0].get_length(); const uint32_t n_head = context.get_input_shape(0)[0].get_length();
// const uint32_t n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
// const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
// const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) const float slope =
// : 1.0f; (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
const float slope = 1.0;
std::shared_ptr<ov::Node> scaled_input;
if (scale != 1.0f) { if (scale != 1.0f) {
auto scale_node = auto scale_node =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale}); std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
input_node = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node); scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
} }
if (context.get_input_size() == 2) { auto mask_node = context.get_input(1);
// Calculate mask then softmax
auto mask_node = context.get_input(1);
ov::element::Type mask_type = context.get_input_type(1);
if (mask_type == ov::element::f16) {
// Convert f16 to f32
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
}
// Stride slice mask node // Use Q-cur to retrieve the token length, so that the translation of SOFT_MAX
Output<Node> slice_start = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0}); // does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); // can be fused into SDPA.
auto token_len = get_dimensions(input_node.get_node_shared_ptr(), {1}); if (input_node->get_type_info() != ov::op::v0::Convert::get_type_info_static()) {
auto total_token_len = get_dimensions(mask_node.get_node_shared_ptr(), {2}); throw std::runtime_error("Input of SOFT_MAX should be MatMul of qk followed by a Convert");
auto slice_end = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{one, token_len, total_token_len}, 0); }
Output<Node> slice_stride = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 1}); auto qk = input_node->get_input_node_shared_ptr(0);
auto mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, slice_start, slice_end, slice_stride); if (qk->get_type_info() != ov::op::v0::MatMul::get_type_info_static()) {
throw std::runtime_error("Input of SOFT_MAX should be MatMul of qk followed by a Convert");
}
auto token_len = get_dimensions(qk->get_input_node_shared_ptr(0), {1});
// slope * mask auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
Output<Node> slope_mask;
if (slope != 1.0f) {
auto slope_node = auto slope_node =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope}); std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope});
auto slope_mask_node = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node); slope_mask = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node);
throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use.");
// input + slope * mask
auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(input_node, slope_mask_node);
// Calculate softmax
res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);
} else {
// Directly softmax
res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
} }
slope_mask = mask_node_sliced;
auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(scaled_input, slope_mask);
res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);
return rename_outputs_with_suffix({res}, context.get_name()); return rename_outputs_with_suffix({res}, context.get_name());
} }

View File

@ -0,0 +1,61 @@
#include "fuse_to_sdpa.hpp"
#include <openvino/core/graph_util.hpp>
#include <openvino/core/rt_info.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/softmax.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/pattern.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
namespace ov {
namespace frontend {
namespace ggml {
namespace pass {
FuseToSDPA::FuseToSDPA() {
const auto m_k = ov::pass::pattern::any_input();
const auto m_q = ov::pass::pattern::any_input();
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
const auto m_scale = ov::pass::pattern::any_input();
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
const auto m_mask = ov::pass::pattern::any_input();
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
const auto m_v = ov::pass::pattern::any_input();
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
const auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto k = pattern_to_output[m_k];
auto q = pattern_to_output[m_q];
auto v = pattern_to_output[m_v];
auto mask = pattern_to_output[m_mask];
auto scale = pattern_to_output[m_scale];
auto v_trans =
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
ov::replace_node(m.get_match_root(), sdpa);
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
return true;
};
register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_qkv, "ov::frontend::ggml::pass::FuseToSDPA"),
callback);
}
} // namespace pass
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,17 @@
#include "openvino/pass/matcher_pass.hpp"
namespace ov {
namespace frontend {
namespace ggml {
namespace pass {
class FuseToSDPA : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA")
FuseToSDPA();
};
} // namespace pass
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -9,6 +9,7 @@
#include <openvino/pass/make_stateful.hpp> #include <openvino/pass/make_stateful.hpp>
#include "input_model.hpp" #include "input_model.hpp"
#include "pass/fuse_to_sdpa.hpp"
namespace ov { namespace ov {
namespace frontend { namespace frontend {
@ -145,6 +146,8 @@ void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs); manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
manager.register_pass<pass::FuseToSDPA>();
} }
manager.run_passes(model); manager.run_passes(model);

View File

@ -65,7 +65,7 @@ template <typename T>
OutputVector translate_1to1_match_2_inputs(const NodeContext& context) { OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto res = std::make_shared<T>(context.get_input(0), context.get_input(1)); auto res = std::make_shared<T>(context.get_input(0), context.get_input(1));
return rename_outputs_with_suffix({ res }, context.get_name()); return rename_outputs_with_suffix({res}, context.get_name());
} }
} // namespace op } // namespace op

View File

@ -17,6 +17,7 @@
#include <openvino/runtime/compiled_model.hpp> #include <openvino/runtime/compiled_model.hpp>
#include <openvino/runtime/infer_request.hpp> #include <openvino/runtime/infer_request.hpp>
#include <openvino/runtime/intel_npu/properties.hpp> #include <openvino/runtime/intel_npu/properties.hpp>
#include <openvino/runtime/properties.hpp>
#include <openvino/runtime/tensor.hpp> #include <openvino/runtime/tensor.hpp>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -88,6 +89,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
if (cache_dir && !is_static) { if (cache_dir && !is_static) {
core.set_property(ov::cache_dir(cache_dir)); core.set_property(ov::cache_dir(cache_dir));
} }
// core.set_property(ov::enable_profiling(true));
static std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>> infer_request_cache; static std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>> infer_request_cache;
static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_input_names_cache; static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_input_names_cache;
@ -256,10 +258,10 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
} else { } else {
if (param_name == "inp_tokens" || param_name == "inp_pos") { if (param_name == "inp_tokens" || param_name == "inp_pos") {
if (is_first_token) { if (is_first_token) {
size_t max_token_len = ggml_decoder->get_max_token_len(); size_t context_size = ggml_decoder->get_context_size();
const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name); const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
std::vector<int32_t> padded_data = pad_input<int32_t>(input_tensor_ggml, 1, max_token_len, 0); std::vector<int32_t> padded_data = pad_input<int32_t>(input_tensor_ggml, 1, context_size, 0);
input_tensor = ov::Tensor(ov::element::i32, ov::Shape{ 1, 1, max_token_len }); input_tensor = ov::Tensor(ov::element::i32, ov::Shape{1, 1, context_size});
auto* data_ptr = input_tensor.data<int32_t>(); auto* data_ptr = input_tensor.data<int32_t>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr); std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else { } else {
@ -267,18 +269,18 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
} }
} else if (param_name == "KQ_mask") { } else if (param_name == "KQ_mask") {
size_t max_token_len = ggml_decoder->get_max_token_len(); size_t context_size = ggml_decoder->get_context_size();
const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name); const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
if (is_first_token) { if (is_first_token) {
std::vector<float> padded_data = std::vector<float> padded_data =
pad_input<float>(input_tensor_ggml, max_token_len, max_token_len, -INFINITY); pad_input<float>(input_tensor_ggml, context_size, context_size, -INFINITY);
set_zero_diagonal(padded_data, max_token_len); set_zero_diagonal(padded_data, context_size);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{ 1, max_token_len, max_token_len }); input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, context_size, context_size});
auto* data_ptr = input_tensor.data<float>(); auto* data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr); std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else { } else {
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, max_token_len, -INFINITY); std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, context_size, -INFINITY);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{ 1, 1, max_token_len }); input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, 1, context_size});
auto* data_ptr = input_tensor.data<float>(); auto* data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr); std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} }