From fd990da706395604cf85bf9ec501e3b66880dd96 Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Sun, 11 Jan 2026 22:17:54 +0800
Subject: [PATCH] ggml: add fused `relu_sqr` op

Signed-off-by: Molly Sophia
---
 ggml/include/ggml.h             |  9 +++++++++
 ggml/src/ggml-cpu/ggml-cpu.c    |  1 +
 ggml/src/ggml-cpu/ops.cpp       |  4 ++++
 ggml/src/ggml-cpu/unary-ops.cpp |  9 +++++++++
 ggml/src/ggml-cpu/unary-ops.h   |  1 +
 ggml/src/ggml-cuda/ggml-cuda.cu |  4 ++++
 ggml/src/ggml-cuda/unary.cu     |  9 +++++++++
 ggml/src/ggml-cuda/unary.cuh    |  2 ++
 ggml/src/ggml.c                 | 17 ++++++++++++++++-
 src/llama-graph.cpp             |  8 ++------
 10 files changed, 57 insertions(+), 7 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index f1366c4195..4a932e8e92 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -579,6 +579,7 @@ extern "C" {
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
         GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_RELU_SQR,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
@@ -1128,6 +1129,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_relu_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_relu_sqr_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_sigmoid(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 0f1f8fffa8..b1c86a8f04 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2221,6 +2221,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_RELU_SQR:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 43e67c7131..0ed9043eae 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9144,6 +9144,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_relu(params, dst);
             } break;
+        case GGML_UNARY_OP_RELU_SQR:
+            {
+                ggml_compute_forward_relu_sqr(params, dst);
+            } break;
         case GGML_UNARY_OP_SIGMOID:
             {
                 ggml_compute_forward_sigmoid(params, dst);
diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp
index 1d9873ad0f..505b24ce2e 100644
--- a/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ggml/src/ggml-cpu/unary-ops.cpp
@@ -28,6 +28,11 @@ static inline float op_relu(float x) {
     return (x > 0.f) ? x : 0.f;
 }
 
+static inline float op_relu_sqr(float x) {
+    float r = (x > 0.f) ? x : 0.f;
+    return r * r;
+}
+
 static inline float op_sigmoid(float x) {
     return 1.f / (1.f + expf(-x));
 }
@@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor *
     unary_op<op_relu>(params, dst);
 }
 
+void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_relu_sqr>(params, dst);
+}
+
 void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_sigmoid>(params, dst);
 }
diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h
index bcad5a3af1..0229950804 100644
--- a/ggml/src/ggml-cpu/unary-ops.h
+++ b/ggml/src/ggml-cpu/unary-ops.h
@@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 456ca40e3d..32468988d3 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2504,6 +2504,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_RELU:
                     ggml_cuda_op_relu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_RELU_SQR:
+                    ggml_cuda_op_relu_sqr(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_SIGMOID:
                     ggml_cuda_op_sigmoid(ctx, dst);
                     break;
@@ -4322,6 +4325,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_RELU_SQR:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index d4866067a4..560fade72d 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) {
     return fmaxf(x, 0);
 }
 
+static __device__ __forceinline__ float op_relu_sqr(float x) {
+    float r = fmaxf(x, 0);
+    return r * r;
+}
+
 static __device__ __forceinline__ float op_sigmoid(float x) {
     return 1.0f / (1.0f + expf(-x));
 }
@@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_relu>(ctx, dst);
 }
 
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_relu_sqr>(ctx, dst);
+}
+
 void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
 }
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 609046e569..769341ac44 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 990f3ffa0e..ad4c7bd207 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1170,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TANH",
     "ELU",
     "RELU",
+    "RELU_SQR",
     "SIGMOID",
     "GELU",
     "GELU_QUICK",
@@ -1187,7 +1188,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TRUNC",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
+static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23");
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
@@ -2645,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
+// ggml_relu_sqr
+
+struct ggml_tensor * ggml_relu_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR);
+}
+
+struct ggml_tensor * ggml_relu_sqr_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR);
+}
+
 // ggml_leaky_relu
 
 struct ggml_tensor * ggml_leaky_relu(
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 374ff1ebf3..9cc80e5bb4 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -919,10 +919,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
         case LLM_FFN_RELU_SQR:
             {
-                cur = ggml_relu(ctx0, cur);
-                cb(cur, "ffn_relu", il);
-
-                cur = ggml_sqr(ctx0, cur);
+                cur = ggml_relu_sqr(ctx0, cur);
                 cb(cur, "ffn_sqr(relu)", il);
             } break;
         case LLM_FFN_SWIGLU:
@@ -1225,8 +1222,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                     // TODO: add support for gated squared relu
                     GGML_ABORT("fatal error: gated squared relu not implemented");
                 } else {
-                    cur = ggml_relu(ctx0, cur);
-                    cur = ggml_sqr(ctx0, cur);
+                    cur = ggml_relu_sqr(ctx0, cur);
                     cb(cur, "ffn_moe_relu_sqr", il);
                 } break;
             default:
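
Not part of the patch itself: a minimal sketch of how the new op could be exercised end-to-end on the
CPU backend, kept here for reference only. It assumes the public ggml API (ggml_init, ggml_new_tensor_1d,
ggml_new_graph, ggml_build_forward_expand, ggml_free) plus the CPU helper ggml_graph_compute_with_ctx
from ggml-cpu.h; the buffer size and thread count are arbitrary. For x = [-2, -1, 1, 2] the fused op should
produce [0, 0, 1, 4], i.e. the same result as the previous ggml_sqr(ggml_relu(x)) composition (max(x, 0)^2).

#include <stdio.h>
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    // small context with pre-allocated tensor data (no_alloc = false)
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // input: [-2, -1, 1, 2]  ->  expected relu_sqr output: [0, 0, 1, 4]
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = -2.0f; xd[1] = -1.0f; xd[2] = 1.0f; xd[3] = 2.0f;

    // single fused node instead of ggml_relu followed by ggml_sqr
    struct ggml_tensor * y = ggml_relu_sqr(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    const float * yd = (const float *) y->data;
    printf("%f %f %f %f\n", yd[0], yd[1], yd[2], yd[3]);

    ggml_free(ctx);
    return 0;
}

In the upstream tree, backend coverage for a new unary op would more likely be added through the existing
operator test harness rather than a standalone program like this; the sketch above only illustrates the API shape.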