diff --git a/ggml/src/ggml-openvino/openvino/op/soft_max.cpp b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp index 401acaf865..046cb93c8b 100644 --- a/ggml/src/ggml-openvino/openvino/op/soft_max.cpp +++ b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp @@ -53,13 +53,7 @@ OutputVector translate_soft_max(const NodeContext& context) { auto mask_node = context.get_input(1); - std::shared_ptr token_len = get_dimensions(input_node, {1}); - // Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX - // does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul - // can be fused into SDPA. - if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) { - token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1}); - } + auto token_len = context.get_input("token_len"); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); std::shared_ptr mask_node_sliced = diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp index aa6e28b627..1b7ac60271 100644 --- a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include @@ -23,13 +22,15 @@ FuseToSDPA::FuseToSDPA() { const auto m_k = ov::pass::pattern::any_input(); const auto m_q = ov::pass::pattern::any_input(); const auto m_qk = ov::pass::pattern::wrap_type({m_q, m_k}); + const auto m_qk_f32 = ov::pass::pattern::wrap_type({m_qk}); const auto m_scale = ov::pass::pattern::any_input(); - const auto m_scaled_qk = ov::pass::pattern::wrap_type({m_qk, m_scale}); + const auto m_scaled_qk = ov::pass::pattern::wrap_type({m_qk_f32, m_scale}); const auto m_mask = ov::pass::pattern::any_input(); const auto m_masked_qk = ov::pass::pattern::wrap_type({m_scaled_qk, m_mask}); const auto m_softmax_qk = ov::pass::pattern::wrap_type({m_masked_qk}); + const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type({m_softmax_qk}); const auto m_v = ov::pass::pattern::any_input(); - const auto m_qkv = ov::pass::pattern::wrap_type({m_softmax_qk, m_v}); + const auto m_qkv = ov::pass::pattern::wrap_type({m_softmax_qk_f16, m_v}); const auto callback = [=](ov::pass::pattern::Matcher& m) { auto& pattern_to_output = m.get_pattern_value_map(); @@ -41,7 +42,9 @@ FuseToSDPA::FuseToSDPA() { auto v_trans = register_new_node(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1})); - auto sdpa = std::make_shared(q, k, v_trans, mask, scale, false); + auto mask_f16 = register_new_node(mask, ov::element::f16); + auto scale_f16 = register_new_node(scale, ov::element::f16); + auto sdpa = std::make_shared(q, k, v_trans, mask_f16, scale_f16, false); ov::replace_node(m.get_match_root(), sdpa); ov::copy_runtime_info(m.get_matched_nodes(), sdpa); diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index ed7db61414..daef12fb90 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -22,7 +22,6 @@ #include #include #include -#include #include "ggml-openvino/openvino/node_context.hpp" #include "ggml-openvino/openvino/utils.hpp" @@ -269,12 +268,9 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr(kv_param_res_pairs); } - // SDPA is even worse on performance manager.register_pass(); manager.run_passes(model); } - auto preprocessor = ov::preprocess::PrePostProcessor(model); - model = preprocessor.build(); return model; }