Add GeGLU

This commit is contained in:
Yu, Zijun 2025-09-15 11:13:59 +08:00 committed by Mustafa Cavus
parent be07073e0e
commit 597561242f
5 changed files with 85 additions and 11 deletions

View File

@ -249,17 +249,30 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
const auto* op_params = op->op_params;
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
const uint32_t h = op->src[0]->ne[2];
const uint32_t n_head = op->src[0]->ne[0];
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
if (max_bias > 0) {
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n");
return true;
}
}
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const float slope =
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
if (slope != 1.0f) {
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with slope != 1.0f\n");
if (op->op == GGML_OP_FLASH_ATTN_EXT) {
if (op->src[4] != nullptr) {
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
return true;
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
const auto* op_params = op->op_params;
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float));
if (max_bias > 0) {
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n");
return true;
}
if (logit_softcap != 0) {
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n");
return true;
}
}
@ -357,7 +370,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
GGML_OP_ROPE,
GGML_OP_RMS_NORM,
GGML_OP_SCALE,
GGML_OP_SOFT_MAX,
// softmax is not updated due to replaced by flash_attn_ext
// GGML_OP_SOFT_MAX,
GGML_OP_SET_ROWS,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_CPY};
@ -366,6 +380,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
};
static const std::set<ggml_glu_op> supported_glu_ops{
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_GEGLU,
};
switch (op->op) {

View File

@ -0,0 +1,50 @@
#include <memory>
#include <openvino/core/node_output.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/gelu.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/split.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
#include "../utils.hpp"
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_glu_geglu(const NodeContext& context) {
num_inputs_check(context, 1, 2);
ov::Output<ov::Node> src0;
ov::Output<ov::Node> src1;
if (context.get_input_size() == 2) {
src0 = context.get_input(0);
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);
}
int32_t* params = context.get_output_op_params(0);
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);
}
auto gelu = std::make_shared<ov::op::v7::Gelu>(src0);
auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1);
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -31,6 +31,13 @@ OutputVector translate_glu_swiglu(const NodeContext& context) {
src0 = split->output(0);
src1 = split->output(1);
}
int32_t* params = context.get_output_op_params(0);
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);
}
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);

View File

@ -34,6 +34,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
{"GGML_UNARY_OP_SILU", op::translate_unary_silu },
{"GGML_OP_VIEW", op::translate_view },
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
{"GGML_GLU_OP_GEGLU", op::translate_glu_geglu },
{"GGML_OP_SET_ROWS", op::translate_set_rows },
{"GGML_OP_CPY", op::translate_cpy },
{"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext },

View File

@ -25,6 +25,7 @@ GGML_OP_CONVERTER(translate_soft_max);
GGML_OP_CONVERTER(translate_transpose);
GGML_OP_CONVERTER(translate_view);
GGML_OP_CONVERTER(translate_glu_swiglu);
GGML_OP_CONVERTER(translate_glu_geglu);
GGML_OP_CONVERTER(translate_set_rows);
GGML_OP_CONVERTER(translate_cpy);
GGML_OP_CONVERTER(translate_flash_attn_ext);