Revert changes in fuse_to_sdpa

This commit is contained in:
Author: Yu, Zijun — 2025-07-30 22:55:41 +08:00; committed by: Mustafa Cavus
parent 1a19566b23
commit 43489bbfaa
3 changed files with 8 additions and 15 deletions

View File

@ -53,13 +53,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
auto mask_node = context.get_input(1);
std::shared_ptr<ov::Node> token_len = get_dimensions(input_node, {1});
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
// can be fused into SDPA.
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
}
auto token_len = context.get_input("token_len");
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
std::shared_ptr<ov::Node> mask_node_sliced =

View File

@ -10,7 +10,6 @@
#include <openvino/op/softmax.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/optional.hpp>
#include <openvino/pass/pattern/op/pattern.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
@ -23,13 +22,15 @@ FuseToSDPA::FuseToSDPA() {
const auto m_k = ov::pass::pattern::any_input();
const auto m_q = ov::pass::pattern::any_input();
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
const auto m_scale = ov::pass::pattern::any_input();
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
const auto m_mask = ov::pass::pattern::any_input();
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
const auto m_v = ov::pass::pattern::any_input();
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
const auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
@ -41,7 +42,9 @@ FuseToSDPA::FuseToSDPA() {
auto v_trans =
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
ov::replace_node(m.get_match_root(), sdpa);
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);

View File

@ -22,7 +22,6 @@
#include <openvino/op/unsqueeze.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/core/preprocess/pre_post_process.hpp>
#include "ggml-openvino/openvino/node_context.hpp"
#include "ggml-openvino/openvino/utils.hpp"
@ -269,12 +268,9 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
}
// SDPA is even worse on performance
manager.register_pass<pass::FuseToSDPA>();
manager.run_passes(model);
}
auto preprocessor = ov::preprocess::PrePostProcessor(model);
model = preprocessor.build();
return model;
}