Fix Phi3 SwiGLU and SoftMax
This commit is contained in:
parent
0fa7a5efef
commit
3533c14cf6
|
|
@ -1,6 +1,11 @@
|
|||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/sigmoid.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/split.hpp>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
#include "../op_table.hpp"
|
||||
|
|
@ -12,13 +17,23 @@ namespace ggml {
|
|||
namespace op {
|
||||
|
||||
OutputVector translate_glu_swiglu(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
num_inputs_check(context, 1, 2);
|
||||
|
||||
auto src1 = context.get_input(0);
|
||||
auto src2 = context.get_input(1);
|
||||
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src1);
|
||||
auto silu = std::make_shared<ov::op::v1::Multiply>(src1, sigmoid);
|
||||
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src2);
|
||||
ov::Output<ov::Node> src0;
|
||||
ov::Output<ov::Node> src1;
|
||||
if (context.get_input_size() == 2) {
|
||||
src0 = context.get_input(0);
|
||||
src1 = context.get_input(1);
|
||||
} else {
|
||||
auto combined = context.get_input(0);
|
||||
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
|
||||
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
|
||||
src0 = split->output(0);
|
||||
src1 = split->output(1);
|
||||
}
|
||||
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
|
||||
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
|
||||
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,12 +43,8 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
|||
const float slope =
|
||||
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
std::shared_ptr<ov::Node> scaled_input;
|
||||
if (scale != 1.0f) {
|
||||
auto scale_node =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
|
||||
scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
|
||||
}
|
||||
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
|
||||
auto scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
|
||||
|
||||
auto mask_node = context.get_input(1);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue