PERF: favor low precision matmul

This commit is contained in:
Yu, Zijun 2025-05-13 08:42:54 +08:00 committed by Mustafa Cavus
parent 0d009fe61a
commit cdf5370cb5
3 changed files with 21 additions and 20 deletions

View File

@ -33,7 +33,7 @@ public:
return m_decoder->get_input_size();
}
Any get_input_type(size_t index) const {
ov::element::Type get_input_type(size_t index) const {
return m_decoder->get_input_type(m_input_names[index]);
}

View File

@ -1,19 +1,18 @@
#include <cstddef>
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <vector>
#include "../node_context.hpp"
#include "../utils.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/node_output.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/transpose.hpp"
namespace ov {
namespace frontend {
@ -25,9 +24,10 @@ OutputVector translate_mulmat(const NodeContext& context) {
bool continuous = context.check_if_continuous();
if (continuous) {
auto src1 = context.get_input(1);
auto src0_converted = std::make_shared<ov::op::v1::ConvertLike>(context.get_input(0), src1);
auto result = std::make_shared<ov::op::v0::MatMul>(src1, src0_converted, false, true);
auto src0 = context.get_input(0);
auto src1 = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
auto result_lp = std::make_shared<ov::op::v0::MatMul>(src1, src0, false, true);
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
return {result};
} else {
/*
@ -94,8 +94,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
B = src0_slice;
}
A = context.get_input(1);
B = std::make_shared<ov::op::v1::ConvertLike>(B, A);
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
int64_t num_heads = context.get_input_shape(1).to_shape()[0];
int64_t num_heads_kv = src0_shape[0];
@ -116,10 +115,12 @@ OutputVector translate_mulmat(const NodeContext& context) {
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
}
auto result = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
return {result};
}
};
}
} // namespace op
} // namespace ggml

View File

@ -49,7 +49,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
if (context.get_input_size() == 2) {
// Calculate mask then softmax
auto mask_node = context.get_input(1);
ov::element::Type mask_type = (context.get_input_type(1)).as<ov::element::Type>();
ov::element::Type mask_type = context.get_input_type(1);
if (mask_type == ov::element::f16) {
// Convert f16 to f32
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
@ -80,7 +80,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
auto res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
return {res};
}
};
}
} // namespace op
} // namespace ggml