Add GeGLU
This commit is contained in:
parent
be07073e0e
commit
597561242f
|
|
@ -249,17 +249,30 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
|
|||
const auto* op_params = op->op_params;
|
||||
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
|
||||
const uint32_t h = op->src[0]->ne[2];
|
||||
const uint32_t n_head = op->src[0]->ne[0];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
if (max_bias > 0) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
const float slope =
|
||||
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
if (slope != 1.0f) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with slope != 1.0f\n");
|
||||
if (op->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
if (op->src[4] != nullptr) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
|
||||
return true;
|
||||
}
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
const auto* op_params = op->op_params;
|
||||
memcpy(&scale, (const float*) op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float*) op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float*) op_params + 2, sizeof(float));
|
||||
if (max_bias > 0) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n");
|
||||
return true;
|
||||
}
|
||||
if (logit_softcap != 0) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -357,7 +370,8 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
|
|||
GGML_OP_ROPE,
|
||||
GGML_OP_RMS_NORM,
|
||||
GGML_OP_SCALE,
|
||||
GGML_OP_SOFT_MAX,
|
||||
// softmax has not been updated because it is replaced by flash_attn_ext
|
||||
// GGML_OP_SOFT_MAX,
|
||||
GGML_OP_SET_ROWS,
|
||||
GGML_OP_FLASH_ATTN_EXT,
|
||||
GGML_OP_CPY};
|
||||
|
|
@ -366,6 +380,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
|
|||
};
|
||||
static const std::set<ggml_glu_op> supported_glu_ops{
|
||||
GGML_GLU_OP_SWIGLU,
|
||||
GGML_GLU_OP_GEGLU,
|
||||
};
|
||||
|
||||
switch (op->op) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
#include <memory>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/gelu.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/sigmoid.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/split.hpp>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
#include "../op_table.hpp"
|
||||
#include "../utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
/// Translates a GEGLU node into OpenVINO ops, computing `gelu(a) * b`.
///
/// The node has either two inputs (`a` = input 0, `b` = input 1) or a single
/// packed input that is split into two equal halves along axis 2 (`a` = first
/// half, `b` = second half). When op param 1 of the output is non-zero the two
/// operands are exchanged before the activation is applied.
///
/// @param context  Node being translated; must expose one or two inputs.
/// @return         A single output (the product), renamed with the node's name.
OutputVector translate_glu_geglu(const NodeContext& context) {
    num_inputs_check(context, 1, 2);

    ov::Output<ov::Node> src0;
    ov::Output<ov::Node> src1;
    if (context.get_input_size() == 2) {
        src0 = context.get_input(0);
        src1 = context.get_input(1);
    } else {
        // Single packed input: cut it into two equally sized halves.
        auto combined = context.get_input(0);
        auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
        auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
        src0 = split->output(0);
        src1 = split->output(1);
    }

    // Op param 1 is the "swapped" flag: non-zero exchanges the activated
    // operand and the multiplier.
    const int32_t* params = context.get_output_op_params(0);
    const bool swapped = params[1] != 0;
    if (swapped) {
        std::swap(src0, src1);
    }

    // Gelu-7 defaults to the exact (erf) formulation. ggml's GEGLU is the
    // tanh-approximated GELU (the erf form is a separate GLU op), so request
    // the tanh mode explicitly to match the reference results.
    auto gelu = std::make_shared<ov::op::v7::Gelu>(src0, ov::op::GeluApproximationMode::TANH);
    auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1);

    return rename_outputs_with_suffix({res}, context.get_name());
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace ggml
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
|
|
@ -31,6 +31,13 @@ OutputVector translate_glu_swiglu(const NodeContext& context) {
|
|||
src0 = split->output(0);
|
||||
src1 = split->output(1);
|
||||
}
|
||||
|
||||
int32_t* params = context.get_output_op_params(0);
|
||||
const int32_t swapped = params[1];
|
||||
if (swapped) {
|
||||
std::swap(src0, src1);
|
||||
}
|
||||
|
||||
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
|
||||
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
|
||||
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
|
|||
{"GGML_UNARY_OP_SILU", op::translate_unary_silu },
|
||||
{"GGML_OP_VIEW", op::translate_view },
|
||||
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
|
||||
{"GGML_GLU_OP_GEGLU", op::translate_glu_geglu },
|
||||
{"GGML_OP_SET_ROWS", op::translate_set_rows },
|
||||
{"GGML_OP_CPY", op::translate_cpy },
|
||||
{"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext },
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ GGML_OP_CONVERTER(translate_soft_max);
|
|||
GGML_OP_CONVERTER(translate_transpose);
|
||||
GGML_OP_CONVERTER(translate_view);
|
||||
GGML_OP_CONVERTER(translate_glu_swiglu);
|
||||
GGML_OP_CONVERTER(translate_glu_geglu);
|
||||
GGML_OP_CONVERTER(translate_set_rows);
|
||||
GGML_OP_CONVERTER(translate_cpy);
|
||||
GGML_OP_CONVERTER(translate_flash_attn_ext);
|
||||
|
|
|
|||
Loading…
Reference in New Issue