PERF: favor low precision matmul
This commit is contained in:
parent
0d009fe61a
commit
cdf5370cb5
|
|
@ -33,7 +33,7 @@ public:
|
||||||
return m_decoder->get_input_size();
|
return m_decoder->get_input_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
Any get_input_type(size_t index) const {
|
ov::element::Type get_input_type(size_t index) const {
|
||||||
return m_decoder->get_input_type(m_input_names[index]);
|
return m_decoder->get_input_type(m_input_names[index]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
#include <cstddef>
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <openvino/core/node.hpp>
|
||||||
|
#include <openvino/core/node_output.hpp>
|
||||||
|
#include <openvino/op/concat.hpp>
|
||||||
|
#include <openvino/op/constant.hpp>
|
||||||
|
#include <openvino/op/convert.hpp>
|
||||||
|
#include <openvino/op/matmul.hpp>
|
||||||
|
#include <openvino/op/reshape.hpp>
|
||||||
|
#include <openvino/op/slice.hpp>
|
||||||
|
#include <openvino/op/transpose.hpp>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../node_context.hpp"
|
#include "../node_context.hpp"
|
||||||
#include "../utils.hpp"
|
#include "../utils.hpp"
|
||||||
#include "openvino/core/node.hpp"
|
|
||||||
#include "openvino/core/node_output.hpp"
|
|
||||||
#include "openvino/op/concat.hpp"
|
|
||||||
#include "openvino/op/constant.hpp"
|
|
||||||
#include "openvino/op/convert_like.hpp"
|
|
||||||
#include "openvino/op/matmul.hpp"
|
|
||||||
#include "openvino/op/reshape.hpp"
|
|
||||||
#include "openvino/op/slice.hpp"
|
|
||||||
#include "openvino/op/transpose.hpp"
|
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace frontend {
|
namespace frontend {
|
||||||
|
|
@ -25,9 +24,10 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
|
|
||||||
bool continuous = context.check_if_continuous();
|
bool continuous = context.check_if_continuous();
|
||||||
if (continuous) {
|
if (continuous) {
|
||||||
auto src1 = context.get_input(1);
|
auto src0 = context.get_input(0);
|
||||||
auto src0_converted = std::make_shared<ov::op::v1::ConvertLike>(context.get_input(0), src1);
|
auto src1 = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
||||||
auto result = std::make_shared<ov::op::v0::MatMul>(src1, src0_converted, false, true);
|
auto result_lp = std::make_shared<ov::op::v0::MatMul>(src1, src0, false, true);
|
||||||
|
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||||
return {result};
|
return {result};
|
||||||
} else {
|
} else {
|
||||||
/*
|
/*
|
||||||
|
|
@ -94,8 +94,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
B = src0_slice;
|
B = src0_slice;
|
||||||
}
|
}
|
||||||
|
|
||||||
A = context.get_input(1);
|
A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
|
||||||
B = std::make_shared<ov::op::v1::ConvertLike>(B, A);
|
|
||||||
|
|
||||||
int64_t num_heads = context.get_input_shape(1).to_shape()[0];
|
int64_t num_heads = context.get_input_shape(1).to_shape()[0];
|
||||||
int64_t num_heads_kv = src0_shape[0];
|
int64_t num_heads_kv = src0_shape[0];
|
||||||
|
|
@ -116,10 +115,12 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
|
B = std::make_shared<ov::op::v1::Reshape>(B, new_B_shape, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, true);
|
||||||
|
auto result = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||||
|
|
||||||
return {result};
|
return {result};
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ggml
|
} // namespace ggml
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
||||||
if (context.get_input_size() == 2) {
|
if (context.get_input_size() == 2) {
|
||||||
// Calculate mask then softmax
|
// Calculate mask then softmax
|
||||||
auto mask_node = context.get_input(1);
|
auto mask_node = context.get_input(1);
|
||||||
ov::element::Type mask_type = (context.get_input_type(1)).as<ov::element::Type>();
|
ov::element::Type mask_type = context.get_input_type(1);
|
||||||
if (mask_type == ov::element::f16) {
|
if (mask_type == ov::element::f16) {
|
||||||
// Convert f16 to f32
|
// Convert f16 to f32
|
||||||
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
|
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
|
||||||
|
|
@ -80,7 +80,7 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
||||||
auto res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
|
auto res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
|
||||||
return {res};
|
return {res};
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ggml
|
} // namespace ggml
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue