Run matmul in fp32 (drop the explicit fp16 Convert nodes around MatMul/SDPA)

This commit is contained in:
Yu, Zijun 2025-07-29 14:07:03 +08:00 committed by Mustafa Cavus
parent 9cf56d6837
commit 01cdf4a9cc
7 changed files with 28 additions and 38 deletions

View File

@ -212,6 +212,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
} else {
m_op_case = 1;
}
break;
}
default:
break;

View File

@ -139,7 +139,7 @@ private:
std::vector<std::string> m_output_names;
std::string m_op_name;
mutable std::string m_name;
int m_op_case;
int m_op_case = 0;
std::vector<std::pair<std::string, std::string>> m_op_node_name;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;

View File

@ -29,15 +29,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
ov::Output<Node> res;
ov::Output<ov::Node> B = context.get_input(0);
ov::Output<ov::Node> A = context.get_input(1);
if (context.get_op_case() == 1) {
if (context.get_input_type(0) == ov::element::f16) {
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), ov::element::f32);
}
if (context.get_input_type(1) == ov::element::f16) {
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), ov::element::f32);
}
} else {
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
if (context.get_input_type(0) != context.get_input_type(1)) {
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
}
auto B_shape = context.get_input_shape(0).to_shape();
@ -72,8 +65,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
A = Z;
}
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -57,11 +57,8 @@ OutputVector translate_soft_max(const NodeContext& context) {
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
// can be fused into SDPA.
if (input_node->get_type_info() == ov::op::v0::Convert::get_type_info_static()) {
auto qk = input_node->get_input_node_shared_ptr(0);
if (qk->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
token_len = get_dimensions(qk->get_input_node_shared_ptr(0), {1});
}
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
}
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});

View File

@ -10,6 +10,7 @@
#include <openvino/op/softmax.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/optional.hpp>
#include <openvino/pass/pattern/op/pattern.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
@ -22,15 +23,13 @@ FuseToSDPA::FuseToSDPA() {
const auto m_k = ov::pass::pattern::any_input();
const auto m_q = ov::pass::pattern::any_input();
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
const auto m_scale = ov::pass::pattern::any_input();
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
const auto m_mask = ov::pass::pattern::any_input();
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
const auto m_v = ov::pass::pattern::any_input();
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
const auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
@ -42,9 +41,7 @@ FuseToSDPA::FuseToSDPA() {
auto v_trans =
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
ov::replace_node(m.get_match_root(), sdpa);
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);

View File

@ -22,7 +22,7 @@
#include <openvino/op/unsqueeze.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp>
#include <openvino/core/preprocess/pre_post_process.hpp>
#include "ggml-openvino/openvino/node_context.hpp"
#include "ggml-openvino/openvino/utils.hpp"
@ -254,22 +254,25 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
return resulting_model;
}
void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model) {
std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<Model> model) {
auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();
{
ov::pass::Manager manager;
manager.set_per_pass_validation(true);
ov::pass::Manager manager;
manager.set_per_pass_validation(true);
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
manager.register_pass<ov::pass::ConstantFolding>();
if (!ggml_model_decoder->is_static()) {
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
}
if (!ggml_model_decoder->is_static()) {
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
// SDPA is even worse on performance
// manager.register_pass<pass::FuseToSDPA>();
manager.run_passes(model);
}
manager.register_pass<pass::FuseToSDPA>();
manager.run_passes(model);
auto preprocessor = ov::preprocess::PrePostProcessor(model);
model = preprocessor.build();
return model;
}
} // namespace ggml

View File

@ -16,7 +16,7 @@ public:
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
private:
void apply_transformations(const std::shared_ptr<Model>& model);
std::shared_ptr<Model> apply_transformations(std::shared_ptr<Model> model);
const frontend::InputModel::Ptr m_input_model;
const std::unordered_map<std::string, CreatorFunction>& m_translator_map;
std::shared_ptr<Model> m_ov_model;