Revert changes in fuse_to_sdpa
This commit is contained in:
parent
1a19566b23
commit
43489bbfaa
|
|
@ -53,13 +53,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
|||
|
||||
auto mask_node = context.get_input(1);
|
||||
|
||||
std::shared_ptr<ov::Node> token_len = get_dimensions(input_node, {1});
|
||||
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
|
||||
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
|
||||
// can be fused into SDPA.
|
||||
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
|
||||
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
|
||||
}
|
||||
auto token_len = context.get_input("token_len");
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
std::shared_ptr<ov::Node> mask_node_sliced =
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@
|
|||
#include <openvino/op/softmax.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/pass/pattern/op/label.hpp>
|
||||
#include <openvino/pass/pattern/op/optional.hpp>
|
||||
#include <openvino/pass/pattern/op/pattern.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
|
|
@ -23,13 +22,15 @@ FuseToSDPA::FuseToSDPA() {
|
|||
const auto m_k = ov::pass::pattern::any_input();
|
||||
const auto m_q = ov::pass::pattern::any_input();
|
||||
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
|
||||
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
|
||||
const auto m_scale = ov::pass::pattern::any_input();
|
||||
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
|
||||
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
|
||||
const auto m_mask = ov::pass::pattern::any_input();
|
||||
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
|
||||
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
|
||||
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
|
||||
const auto m_v = ov::pass::pattern::any_input();
|
||||
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
|
||||
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
|
||||
|
||||
const auto callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
|
@ -41,7 +42,9 @@ FuseToSDPA::FuseToSDPA() {
|
|||
|
||||
auto v_trans =
|
||||
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
|
||||
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
|
||||
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
|
||||
|
||||
ov::replace_node(m.get_match_root(), sdpa);
|
||||
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@
|
|||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
#include <openvino/core/preprocess/pre_post_process.hpp>
|
||||
|
||||
#include "ggml-openvino/openvino/node_context.hpp"
|
||||
#include "ggml-openvino/openvino/utils.hpp"
|
||||
|
|
@ -269,12 +268,9 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
|
|||
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||
}
|
||||
|
||||
// SDPA is even worse on performance
|
||||
manager.register_pass<pass::FuseToSDPA>();
|
||||
manager.run_passes(model);
|
||||
}
|
||||
auto preprocessor = ov::preprocess::PrePostProcessor(model);
|
||||
model = preprocessor.build();
|
||||
return model;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue