matmul in fp32
This commit is contained in:
parent
9cf56d6837
commit
01cdf4a9cc
|
|
@ -212,6 +212,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
|||
} else {
|
||||
m_op_case = 1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ private:
|
|||
std::vector<std::string> m_output_names;
|
||||
std::string m_op_name;
|
||||
mutable std::string m_name;
|
||||
int m_op_case;
|
||||
int m_op_case = 0;
|
||||
std::vector<std::pair<std::string, std::string>> m_op_node_name;
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs;
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;
|
||||
|
|
|
|||
|
|
@ -29,15 +29,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
ov::Output<Node> res;
|
||||
ov::Output<ov::Node> B = context.get_input(0);
|
||||
ov::Output<ov::Node> A = context.get_input(1);
|
||||
if (context.get_op_case() == 1) {
|
||||
if (context.get_input_type(0) == ov::element::f16) {
|
||||
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), ov::element::f32);
|
||||
}
|
||||
if (context.get_input_type(1) == ov::element::f16) {
|
||||
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), ov::element::f32);
|
||||
}
|
||||
} else {
|
||||
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
||||
if (context.get_input_type(0) != context.get_input_type(1)) {
|
||||
B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1));
|
||||
}
|
||||
|
||||
auto B_shape = context.get_input_shape(0).to_shape();
|
||||
|
|
@ -72,8 +65,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
A = Z;
|
||||
}
|
||||
|
||||
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,11 +57,8 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
|||
// Try using Q-cur to retrieve the token length, so that the translation of SOFT_MAX
|
||||
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
|
||||
// can be fused into SDPA.
|
||||
if (input_node->get_type_info() == ov::op::v0::Convert::get_type_info_static()) {
|
||||
auto qk = input_node->get_input_node_shared_ptr(0);
|
||||
if (qk->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
|
||||
token_len = get_dimensions(qk->get_input_node_shared_ptr(0), {1});
|
||||
}
|
||||
if (input_node->get_type_info() == ov::op::v0::MatMul::get_type_info_static()) {
|
||||
token_len = get_dimensions(input_node->get_input_node_shared_ptr(0), {1});
|
||||
}
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <openvino/op/softmax.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/pass/pattern/op/label.hpp>
|
||||
#include <openvino/pass/pattern/op/optional.hpp>
|
||||
#include <openvino/pass/pattern/op/pattern.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
|
|
@ -22,15 +23,13 @@ FuseToSDPA::FuseToSDPA() {
|
|||
const auto m_k = ov::pass::pattern::any_input();
|
||||
const auto m_q = ov::pass::pattern::any_input();
|
||||
const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
|
||||
const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
|
||||
const auto m_scale = ov::pass::pattern::any_input();
|
||||
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
|
||||
const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk, m_scale});
|
||||
const auto m_mask = ov::pass::pattern::any_input();
|
||||
const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
|
||||
const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
|
||||
const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
|
||||
const auto m_v = ov::pass::pattern::any_input();
|
||||
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
|
||||
const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk, m_v});
|
||||
|
||||
const auto callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
|
@ -42,9 +41,7 @@ FuseToSDPA::FuseToSDPA() {
|
|||
|
||||
auto v_trans =
|
||||
register_new_node<ov::op::v1::Transpose>(v, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
|
||||
auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
|
||||
auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask_f16, scale_f16, false);
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v_trans, mask, scale, false);
|
||||
|
||||
ov::replace_node(m.get_match_root(), sdpa);
|
||||
ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@
|
|||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
#include <transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp>
|
||||
#include <openvino/core/preprocess/pre_post_process.hpp>
|
||||
|
||||
#include "ggml-openvino/openvino/node_context.hpp"
|
||||
#include "ggml-openvino/openvino/utils.hpp"
|
||||
|
|
@ -254,22 +254,25 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
|
|||
return resulting_model;
|
||||
}
|
||||
|
||||
void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model) {
|
||||
std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<Model> model) {
|
||||
auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();
|
||||
{
|
||||
ov::pass::Manager manager;
|
||||
manager.set_per_pass_validation(true);
|
||||
|
||||
ov::pass::Manager manager;
|
||||
manager.set_per_pass_validation(true);
|
||||
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
if (!ggml_model_decoder->is_static()) {
|
||||
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
||||
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||
}
|
||||
|
||||
if (!ggml_model_decoder->is_static()) {
|
||||
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
||||
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||
// SDPA is even worse on performance
|
||||
// manager.register_pass<pass::FuseToSDPA>();
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
manager.register_pass<pass::FuseToSDPA>();
|
||||
manager.run_passes(model);
|
||||
auto preprocessor = ov::preprocess::PrePostProcessor(model);
|
||||
model = preprocessor.build();
|
||||
return model;
|
||||
}
|
||||
|
||||
} // namespace ggml
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ public:
|
|||
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
|
||||
|
||||
private:
|
||||
void apply_transformations(const std::shared_ptr<Model>& model);
|
||||
std::shared_ptr<Model> apply_transformations(std::shared_ptr<Model> model);
|
||||
const frontend::InputModel::Ptr m_input_model;
|
||||
const std::unordered_map<std::string, CreatorFunction>& m_translator_map;
|
||||
std::shared_ptr<Model> m_ov_model;
|
||||
|
|
|
|||
Loading…
Reference in New Issue