diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e8548b053e..8e0836c0be 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 40cacb4652..4c192da650 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 62b028f4a4..e7433f2a65 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -127,6 +127,7 @@ #define OP_UNARY_NUM_CEIL 118 #define OP_UNARY_NUM_ROUND 119 #define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 846225d907..5b426be103 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { args.max = ggml_get_op_params_f32(op, 1); } + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + args.slope = ggml_get_op_params_f32(op, 1); // alpha_n + args.scale = ggml_get_op_params_f32(op, 2); // alpha_p + args.bias = ggml_get_op_params_f32(op, 3); // beta + args.val = ggml_get_op_params_f32(op, 4); // eps + } + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); if (pipeline.c4) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f67c5cd8a1..445a4deca8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl( if (FC_OP == OP_UNARY_NUM_TRUNC) { dst_ptr[i0] = (T) trunc(x); } + + if (FC_OP == OP_UNARY_NUM_XIELU) { + const TC xi = x; + const TC gate = TC(xi > TC(0.0f)); + const TC clamped = fmin(xi, TC(args.val)); + const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi; + const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi; + dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg); + } } #undef FC_OP diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e54e6b2bb4..828a9c14a4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8506,6 +8506,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 })); test_cases.emplace_back(new test_xielu()); + test_cases.emplace_back(new test_xielu(GGML_TYPE_F16)); + test_cases.emplace_back(new test_xielu(GGML_TYPE_F32, { 512, 16, 1, 1 })); + test_cases.emplace_back(new test_xielu(GGML_TYPE_F16, { 512, 16, 1, 1 })); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER)); test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));