Fix Phi3 SwiGLU and SoftMax

This commit is contained in:
Yu, Zijun 2025-07-09 10:16:06 +08:00 committed by Mustafa Cavus
parent 0fa7a5efef
commit 3533c14cf6
2 changed files with 23 additions and 12 deletions

View File

@ -1,6 +1,11 @@
#include <cstdint>
#include <memory>
#include <openvino/core/node_output.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/split.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
@ -12,13 +17,23 @@ namespace ggml {
namespace op {
OutputVector translate_glu_swiglu(const NodeContext& context) {
num_inputs_check(context, 2, 2);
num_inputs_check(context, 1, 2);
auto src1 = context.get_input(0);
auto src2 = context.get_input(1);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src1);
auto silu = std::make_shared<ov::op::v1::Multiply>(src1, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src2);
ov::Output<ov::Node> src0;
ov::Output<ov::Node> src1;
if (context.get_input_size() == 2) {
src0 = context.get_input(0);
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);
}
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -43,12 +43,8 @@ OutputVector translate_soft_max(const NodeContext& context) {
const float slope =
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
std::shared_ptr<ov::Node> scaled_input;
if (scale != 1.0f) {
auto scale_node =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
}
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
auto scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
auto mask_node = context.get_input(1);