metal : add XIELU unary op (#20802)
This commit is contained in:
parent
be76dd0bb2
commit
aa0f1897b7
|
|
@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal
|
|||
case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break;
|
||||
case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break;
|
||||
case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break;
|
||||
case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
|
|
|
|||
|
|
@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@
|
|||
#define OP_UNARY_NUM_CEIL 118
|
||||
#define OP_UNARY_NUM_ROUND 119
|
||||
#define OP_UNARY_NUM_TRUNC 120
|
||||
#define OP_UNARY_NUM_XIELU 121
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
|
|
|||
|
|
@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
|||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 1); // alpha_n
|
||||
args.scale = ggml_get_op_params_f32(op, 2); // alpha_p
|
||||
args.bias = ggml_get_op_params_f32(op, 3); // beta
|
||||
args.val = ggml_get_op_params_f32(op, 4); // eps
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
|
|
|
|||
|
|
@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl(
|
|||
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
||||
dst_ptr[i0] = (T) trunc(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
||||
const TC xi = x;
|
||||
const TC gate = TC(xi > TC(0.0f));
|
||||
const TC clamped = fmin(xi, TC(args.val));
|
||||
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
||||
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
||||
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
|
|
|
|||
|
|
@ -8506,6 +8506,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_xielu());
|
||||
test_cases.emplace_back(new test_xielu(GGML_TYPE_F16));
|
||||
test_cases.emplace_back(new test_xielu(GGML_TYPE_F32, { 512, 16, 1, 1 }));
|
||||
test_cases.emplace_back(new test_xielu(GGML_TYPE_F16, { 512, 16, 1, 1 }));
|
||||
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER));
|
||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER_DIAG));
|
||||
|
|
|
|||
Loading…
Reference in New Issue