Perf: RMS fused to OV internal RMS op
This commit is contained in:
parent
a7b611bc93
commit
14c8a85c32
|
|
@ -3,6 +3,7 @@
|
||||||
#include <openvino/op/constant.hpp>
|
#include <openvino/op/constant.hpp>
|
||||||
#include <openvino/op/divide.hpp>
|
#include <openvino/op/divide.hpp>
|
||||||
#include <openvino/op/multiply.hpp>
|
#include <openvino/op/multiply.hpp>
|
||||||
|
#include <openvino/op/power.hpp>
|
||||||
#include <openvino/op/reduce_mean.hpp>
|
#include <openvino/op/reduce_mean.hpp>
|
||||||
#include <openvino/op/sqrt.hpp>
|
#include <openvino/op/sqrt.hpp>
|
||||||
|
|
||||||
|
|
@ -19,18 +20,17 @@ OutputVector translate_rms_norm(const NodeContext& context) {
|
||||||
num_inputs_check(context, 1, 1);
|
num_inputs_check(context, 1, 1);
|
||||||
|
|
||||||
auto input_node = context.get_input(0);
|
auto input_node = context.get_input(0);
|
||||||
auto square = std::make_shared<ov::op::v1::Multiply>(input_node, input_node);
|
auto square = std::make_shared<ov::op::v1::Power>(
|
||||||
|
input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f}));
|
||||||
|
|
||||||
auto mean =
|
auto mean = std::make_shared<ov::op::v1::ReduceMean>(
|
||||||
std::make_shared<ov::op::v1::ReduceMean>(square,
|
square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);
|
||||||
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}),
|
|
||||||
true);
|
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, context.get_output_op_params(0), sizeof(float));
|
memcpy(&eps, context.get_output_op_params(0), sizeof(float));
|
||||||
|
|
||||||
auto rms = std::make_shared<ov::op::v0::Sqrt>(
|
auto rms = std::make_shared<ov::op::v0::Sqrt>(
|
||||||
std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {eps})));
|
std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps})));
|
||||||
|
|
||||||
auto reciprocal =
|
auto reciprocal =
|
||||||
std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms);
|
std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue