Add SwiGLU

This commit is contained in:
Yu, Zijun 2025-07-03 11:03:40 +08:00 committed by Mustafa Cavus
parent 4c582ac7a3
commit 73ee84fffe
6 changed files with 123 additions and 70 deletions

View File

@ -4,6 +4,7 @@ AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false AlignConsecutiveDeclarations: false
ReferenceAlignment: Left ReferenceAlignment: Left
PointerAlignment: Left PointerAlignment: Left
Cpp11BracedListStyle: true
Language: Cpp Language: Cpp
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
@ -65,7 +66,6 @@ CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false CompactNamespaces: false
ConstructorInitializerIndentWidth: 4 ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4 ContinuationIndentWidth: 4
Cpp11BracedListStyle: false
DerivePointerAlignment: false DerivePointerAlignment: false
DisableFormat: false DisableFormat: false
EmptyLineBeforeAccessModifier: Leave EmptyLineBeforeAccessModifier: Leave

View File

@ -563,43 +563,58 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
} }
const std::string& GgmlOvDecoder::get_op_type() const { const std::string& GgmlOvDecoder::get_op_type() const {
static const std::map<ggml_op, std::string> opTypeMap = { static const std::map<ggml_op, std::string> ops = {
{GGML_OP_ACC, "GGML_OP_ACC"}, {GGML_OP_ADD, "GGML_OP_ADD"}, {GGML_OP_ACC, "GGML_OP_ACC" },
{GGML_OP_ADD1, "GGML_OP_ADD1"}, {GGML_OP_CONT, "GGML_OP_CONT"}, {GGML_OP_ADD, "GGML_OP_ADD" },
{GGML_OP_CPY, "GGML_OP_CPY"}, {GGML_OP_DIV, "GGML_OP_DIV"}, {GGML_OP_ADD1, "GGML_OP_ADD1" },
{GGML_OP_DUP, "GGML_OP_DUP"}, {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS"}, {GGML_OP_CONT, "GGML_OP_CONT" },
{GGML_OP_MUL, "GGML_OP_MUL"}, {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT"}, {GGML_OP_CPY, "GGML_OP_CPY" },
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE"}, {GGML_OP_RESHAPE, "GGML_OP_RESHAPE"}, {GGML_OP_DIV, "GGML_OP_DIV" },
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM"}, {GGML_OP_ROPE, "GGML_OP_ROPE"}, {GGML_OP_DUP, "GGML_OP_DUP" },
{GGML_OP_SCALE, "GGML_OP_SCALE"}, {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX"}, {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
{GGML_OP_SUB, "GGML_OP_SUB"}, {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE"}, {GGML_OP_MUL, "GGML_OP_MUL" },
{GGML_OP_UNARY, "GGML_OP_UNARY"}, {GGML_OP_VIEW, "GGML_OP_VIEW"}}; {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
static const std::map<ggml_unary_op, std::string> unaryOpTypeMap = { {GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS"}, {GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
{GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN"}, {GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG"}, {GGML_OP_ROPE, "GGML_OP_ROPE" },
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP"}, {GGML_OP_SCALE, "GGML_OP_SCALE" },
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH"}, {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU"}, {GGML_OP_SUB, "GGML_OP_SUB" },
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU"}, {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE"},
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID"}, {GGML_OP_VIEW, "GGML_OP_VIEW" }
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU"}, };
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK"}, static const std::map<ggml_unary_op, std::string> unary_ops = {
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU"}, {GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH"}, {GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" },
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" },
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" },
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" },
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" },
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" },
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" },
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" },
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" },
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" },
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" },
{GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"}, {GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
{GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP"}, {GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" },
{GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT"}}; {GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" }
auto it = opTypeMap.find(m_node->op); };
if (it != opTypeMap.end()) { static const std::map<ggml_glu_op, std::string> glu_ops = {
if (it->first == GGML_OP_UNARY) { {GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
auto unary_it = unaryOpTypeMap.find(ggml_get_unary_op(m_node)); {GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" },
if (unary_it != unaryOpTypeMap.end()) { {GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
return unary_it->second; };
}
} switch (m_node->op) {
return it->second; case GGML_OP_UNARY:
return unary_ops.at(ggml_get_unary_op(m_node));
case GGML_OP_GLU:
return glu_ops.at(ggml_get_glu_op(m_node));
default:
return ops.at(m_node->op);
} }
static const std::string unknown_op = "UNKNOWN_OP"; static const std::string unknown_op = "UNKNOWN_GGML_OP";
return unknown_op; return unknown_op;
} }

View File

@ -237,21 +237,29 @@ static ggml_backend_buffer_t ggml_backend_openvino_device_buffer_from_host_ptr(g
static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_ASSERT(dev->reg != nullptr); GGML_ASSERT(dev->reg != nullptr);
static const std::set<ggml_op> supported_ops{ static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT,
GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW, GGML_OP_VIEW, GGML_OP_CONT, GGML_OP_CPY, GGML_OP_RESHAPE,
GGML_OP_CONT, GGML_OP_CPY, GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE,
GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX};
GGML_OP_SCALE, GGML_OP_SOFT_MAX, static const std::set<ggml_unary_op> supported_unary_ops{
}; GGML_UNARY_OP_SILU,
static const std::set<ggml_unary_op> supported_unary_ops{ };
GGML_UNARY_OP_SILU, static const std::set<ggml_glu_op> supported_glu_ops{
}; GGML_GLU_OP_SWIGLU,
};
if (op->op == GGML_OP_UNARY) { auto res = false;
return supported_unary_ops.find(ggml_get_unary_op(op)) != switch (op->op) {
supported_unary_ops.end(); case GGML_OP_UNARY:
} res = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end();
return supported_ops.find(op->op) != supported_ops.end(); break;
case GGML_OP_GLU:
res = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end();
break;
default:
res = supported_ops.find(op->op) != supported_ops.end();
}
return res;
} }
static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {

View File

@ -0,0 +1,29 @@
#include <openvino/core/node_output.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
#include "../utils.hpp"
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_glu_swiglu(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto src1 = context.get_input(0);
auto src2 = context.get_input(1);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src1);
auto silu = std::make_shared<ov::op::v1::Multiply>(src1, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src2);
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -16,24 +16,25 @@ namespace ggml {
std::unordered_map<std::string, CreatorFunction> get_supported_ops() { std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
using namespace ov::op; using namespace ov::op;
return { return {
{ "GGML_OP_ADD", op::translate_1to1_match_2_inputs<v1::Add> }, {"GGML_OP_ADD", op::translate_1to1_match_2_inputs<v1::Add> },
{ "GGML_OP_ADD1", op::translate_1to1_match_2_inputs<v1::Add> }, {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs<v1::Add> },
{ "GGML_OP_CONT", op::translate_cont }, {"GGML_OP_CONT", op::translate_cont },
{ "GGML_OP_CPY", op::translate_cpy }, {"GGML_OP_CPY", op::translate_cpy },
{ "GGML_OP_DIV", op::translate_1to1_match_2_inputs<v1::Divide> }, {"GGML_OP_DIV", op::translate_1to1_match_2_inputs<v1::Divide> },
{ "GGML_OP_GET_ROWS", op::translate_get_rows }, {"GGML_OP_GET_ROWS", op::translate_get_rows },
{ "GGML_OP_MUL", op::translate_1to1_match_2_inputs<v1::Multiply> }, {"GGML_OP_MUL", op::translate_1to1_match_2_inputs<v1::Multiply>},
{ "GGML_OP_MUL_MAT", op::translate_mulmat }, {"GGML_OP_MUL_MAT", op::translate_mulmat },
{ "GGML_OP_PERMUTE", op::translate_permute }, {"GGML_OP_PERMUTE", op::translate_permute },
{ "GGML_OP_RESHAPE", op::translate_reshape }, {"GGML_OP_RESHAPE", op::translate_reshape },
{ "GGML_OP_RMS_NORM", op::translate_rms_norm }, {"GGML_OP_RMS_NORM", op::translate_rms_norm },
{ "GGML_OP_ROPE", op::translate_rope }, {"GGML_OP_ROPE", op::translate_rope },
{ "GGML_OP_SCALE", op::translate_scale }, {"GGML_OP_SCALE", op::translate_scale },
{ "GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SOFT_MAX", op::translate_soft_max },
{ "GGML_OP_SUB", op::translate_1to1_match_2_inputs<v1::Subtract> }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs<v1::Subtract>},
{ "GGML_OP_TRANSPOSE", op::translate_transpose }, {"GGML_OP_TRANSPOSE", op::translate_transpose },
{ "GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu },
{ "GGML_OP_VIEW", op::translate_view } {"GGML_OP_VIEW", op::translate_view },
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
}; };
} }

View File

@ -24,8 +24,8 @@ GGML_OP_CONVERTER(translate_scale);
GGML_OP_CONVERTER(translate_unary_silu); GGML_OP_CONVERTER(translate_unary_silu);
GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_soft_max);
GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_transpose);
GGML_OP_CONVERTER(translate_unary);
GGML_OP_CONVERTER(translate_view); GGML_OP_CONVERTER(translate_view);
GGML_OP_CONVERTER(translate_glu_swiglu);
} // namespace op } // namespace op