ggml: add fused `relu_sqr` op

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
Molly Sophia 2026-01-11 22:17:54 +08:00
parent a98cac62c9
commit fd990da706
10 changed files with 57 additions and 7 deletions

View File

@ -579,6 +579,7 @@ extern "C" {
GGML_UNARY_OP_TANH,
GGML_UNARY_OP_ELU,
GGML_UNARY_OP_RELU,
GGML_UNARY_OP_RELU_SQR,
GGML_UNARY_OP_SIGMOID,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
@ -1128,6 +1129,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_relu_sqr_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sigmoid(
struct ggml_context * ctx,
struct ggml_tensor * a);

View File

@ -2221,6 +2221,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_RELU_SQR:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:

View File

@ -9144,6 +9144,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_relu(params, dst);
} break;
case GGML_UNARY_OP_RELU_SQR:
{
ggml_compute_forward_relu_sqr(params, dst);
} break;
case GGML_UNARY_OP_SIGMOID:
{
ggml_compute_forward_sigmoid(params, dst);

View File

@ -28,6 +28,11 @@ static inline float op_relu(float x) {
return (x > 0.f) ? x : 0.f;
}
static inline float op_relu_sqr(float x) {
float r = (x > 0.f) ? x : 0.f;
return r * r;
}
static inline float op_sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor *
unary_op<op_relu>(params, dst);
}
void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_relu_sqr>(params, dst);
}
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_sigmoid>(params, dst);
}

View File

@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -2504,6 +2504,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_RELU:
ggml_cuda_op_relu(ctx, dst);
break;
case GGML_UNARY_OP_RELU_SQR:
ggml_cuda_op_relu_sqr(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
ggml_cuda_op_sigmoid(ctx, dst);
break;
@ -4322,6 +4325,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_RELU_SQR:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:

View File

@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) {
return fmaxf(x, 0);
}
static __device__ __forceinline__ float op_relu_sqr(float x) {
float r = fmaxf(x, 0);
return r * r;
}
static __device__ __forceinline__ float op_sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}
@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu>(ctx, dst);
}
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu_sqr>(ctx, dst);
}
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
}

View File

@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1170,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TANH",
"ELU",
"RELU",
"RELU_SQR",
"SIGMOID",
"GELU",
"GELU_QUICK",
@ -1187,7 +1188,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"TRUNC",
};
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
@ -2645,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}
// ggml_relu_sqr
struct ggml_tensor * ggml_relu_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR);
}
struct ggml_tensor * ggml_relu_sqr_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR);
}
// ggml_leaky_relu
struct ggml_tensor * ggml_leaky_relu(

View File

@ -919,10 +919,7 @@ ggml_tensor * llm_graph_context::build_ffn(
} break;
case LLM_FFN_RELU_SQR:
{
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_relu", il);
cur = ggml_sqr(ctx0, cur);
cur = ggml_relu_sqr(ctx0, cur);
cb(cur, "ffn_sqr(relu)", il);
} break;
case LLM_FFN_SWIGLU:
@ -1225,8 +1222,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {
cur = ggml_relu(ctx0, cur);
cur = ggml_sqr(ctx0, cur);
cur = ggml_relu_sqr(ctx0, cur);
cb(cur, "ffn_moe_relu_sqr", il);
} break;
default: