From 597561242f54da7913509004f059b085f08618a5 Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Mon, 15 Sep 2025 11:13:59 +0800 Subject: [PATCH] Add GeGLU --- ggml/src/ggml-openvino/ggml-openvino.cpp | 37 ++++++++++---- .../ggml-openvino/openvino/op/glu_geglu.cpp | 50 +++++++++++++++++++ .../ggml-openvino/openvino/op/glu_swiglu.cpp | 7 +++ ggml/src/ggml-openvino/openvino/op_table.cpp | 1 + ggml/src/ggml-openvino/openvino/op_table.hpp | 1 + 5 files changed, 85 insertions(+), 11 deletions(-) create mode 100644 ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 60a2eb388e..6da653716f 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -249,17 +249,30 @@ static bool is_op_unsupported_case(const ggml_tensor* op) { const auto* op_params = op->op_params; memcpy(&scale, (const float*) op_params + 0, sizeof(float)); memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); - const uint32_t h = op->src[0]->ne[2]; - const uint32_t n_head = op->src[0]->ne[0]; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + return true; + } + } - const float m0 = powf(2.0f, -(max_bias) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const float slope = - (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; - - if (slope != 1.0f) { - GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with slope != 1.0f\n"); + if (op->op == GGML_OP_FLASH_ATTN_EXT) { + if (op->src[4] != nullptr) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + const auto* op_params = op->op_params; + memcpy(&scale, (const float*) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float*) op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float)); + if (max_bias > 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + return true; + } + if (logit_softcap != 0) { + GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); return true; } } @@ -357,7 +370,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE, - GGML_OP_SOFT_MAX, + // softmax is not updated due to replaced by flash_attn_ext + // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; @@ -366,6 +380,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con }; static const std::set supported_glu_ops{ GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU, }; switch (op->op) { diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp new file mode 100644 index 0000000000..4295bf7517 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../op_table.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_geglu(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + auto combined = context.get_input(0); + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2}); + auto split = std::make_shared(combined, split_axis, 2); + src0 = split->output(0); + src1 = split->output(1); + } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto gelu = std::make_shared(src0); + auto res = std::make_shared(gelu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp index 138ef65090..bef42fe4b7 100644 --- a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -31,6 +31,13 @@ OutputVector translate_glu_swiglu(const NodeContext& context) { src0 = split->output(0); src1 = split->output(1); } + + int32_t* params = context.get_output_op_params(0); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + auto sigmoid = std::make_shared(src0); auto silu = std::make_shared(src0, sigmoid); auto res = std::make_shared(silu, src1); diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index ee55f84b96..e36e8f17cc 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -34,6 +34,7 @@ std::unordered_map get_supported_ops() { {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, + {"GGML_GLU_OP_GEGLU", op::translate_glu_geglu }, {"GGML_OP_SET_ROWS", op::translate_set_rows }, {"GGML_OP_CPY", op::translate_cpy }, {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.hpp b/ggml/src/ggml-openvino/openvino/op_table.hpp index faa61f5f6c..5d4f053860 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.hpp +++ b/ggml/src/ggml-openvino/openvino/op_table.hpp @@ -25,6 +25,7 @@ GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); GGML_OP_CONVERTER(translate_glu_swiglu); +GGML_OP_CONVERTER(translate_glu_geglu); GGML_OP_CONVERTER(translate_set_rows); GGML_OP_CONVERTER(translate_cpy); GGML_OP_CONVERTER(translate_flash_attn_ext);