PERF: favor low-precision matmul

This commit is contained in:
Yu, Zijun 2025-05-13 08:42:54 +08:00 committed by Mustafa Cavus
parent 0d009fe61a
commit cdf5370cb5
3 changed files with 21 additions and 20 deletions

View File

@ -33,7 +33,7 @@ public:
return m_decoder->get_input_size(); return m_decoder->get_input_size();
} }
Any get_input_type(size_t index) const { ov::element::Type get_input_type(size_t index) const {
return m_decoder->get_input_type(m_input_names[index]); return m_decoder->get_input_type(m_input_names[index]);
} }

View File

@ -1,19 +1,18 @@
#include <cstddef>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <vector> #include <vector>
#include "../node_context.hpp" #include "../node_context.hpp"
#include "../utils.hpp" #include "../utils.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/transpose.hpp"
namespace ov { namespace ov {
namespace frontend { namespace frontend {
@ -25,9 +24,10 @@ OutputVector translate_mulmat(const NodeContext& context) {
bool continuous = context.check_if_continuous(); bool continuous = context.check_if_continuous();
if (continuous) { if (continuous) {
auto src1 = context.get_input(1); auto src0 = context.get_input(0);
auto src0_converted = std::make_shared<ov::op::v1::ConvertLike>(context.get_input(0), src1); auto src1 = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
auto result = std::make_shared<ov::op::v0::MatMul>(src1, src0_converted, false, true); auto result_lp = std::make_shared<ov::op::v0::MatMul>(src1, src0, false, true);
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
return {result}; return {result};
} else { } else {
/* /*
@ -94,8 +94,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
B = src0_slice; B = src0_slice;
} }
A = context.get_input(1); A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
B = std::make_shared<ov::op::v1::ConvertLike>(B, A);
int64_t num_heads = context.get_input_shape(1).to_shape()[0]; int64_t num_heads = context.get_input_shape(1).to_shape()[0];
int64_t num_heads_kv = src0_shape[0]; int64_t num_heads_kv = src0_shape[0];
@ -116,10 +115,12 @@ OutputVector translate_mulmat(const NodeContext& context) {
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false); B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
} }
auto result = std::make_shared<ov::op::v0::MatMul>(A, B, false, true); auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
return {result}; return {result};
} }
}; }
} // namespace op } // namespace op
} // namespace ggml } // namespace ggml

View File

@ -49,7 +49,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
if (context.get_input_size() == 2) { if (context.get_input_size() == 2) {
// Calculate mask then softmax // Calculate mask then softmax
auto mask_node = context.get_input(1); auto mask_node = context.get_input(1);
ov::element::Type mask_type = (context.get_input_type(1)).as<ov::element::Type>(); ov::element::Type mask_type = context.get_input_type(1);
if (mask_type == ov::element::f16) { if (mask_type == ov::element::f16) {
// Convert f16 to f32 // Convert f16 to f32
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32); mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
@ -80,7 +80,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
auto res = std::make_shared<ov::op::v8::Softmax>(input_node, 0); auto res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
return {res}; return {res};
} }
}; }
} // namespace op } // namespace op
} // namespace ggml } // namespace ggml