ggml: add fused `relu_sqr` op
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
a98cac62c9
commit
fd990da706
|
|
@ -579,6 +579,7 @@ extern "C" {
|
||||||
GGML_UNARY_OP_TANH,
|
GGML_UNARY_OP_TANH,
|
||||||
GGML_UNARY_OP_ELU,
|
GGML_UNARY_OP_ELU,
|
||||||
GGML_UNARY_OP_RELU,
|
GGML_UNARY_OP_RELU,
|
||||||
|
GGML_UNARY_OP_RELU_SQR,
|
||||||
GGML_UNARY_OP_SIGMOID,
|
GGML_UNARY_OP_SIGMOID,
|
||||||
GGML_UNARY_OP_GELU,
|
GGML_UNARY_OP_GELU,
|
||||||
GGML_UNARY_OP_GELU_QUICK,
|
GGML_UNARY_OP_GELU_QUICK,
|
||||||
|
|
@ -1128,6 +1129,14 @@ extern "C" {
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_relu_sqr(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_relu_sqr_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_sigmoid(
|
GGML_API struct ggml_tensor * ggml_sigmoid(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
|
||||||
|
|
@ -2221,6 +2221,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_ELU:
|
case GGML_UNARY_OP_ELU:
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
|
case GGML_UNARY_OP_RELU_SQR:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
|
|
||||||
|
|
@ -9144,6 +9144,10 @@ void ggml_compute_forward_unary(
|
||||||
{
|
{
|
||||||
ggml_compute_forward_relu(params, dst);
|
ggml_compute_forward_relu(params, dst);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_UNARY_OP_RELU_SQR:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_relu_sqr(params, dst);
|
||||||
|
} break;
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_sigmoid(params, dst);
|
ggml_compute_forward_sigmoid(params, dst);
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,11 @@ static inline float op_relu(float x) {
|
||||||
return (x > 0.f) ? x : 0.f;
|
return (x > 0.f) ? x : 0.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Squared ReLU: f(x) = max(x, 0)^2 — the fused form of relu followed by sqr.
// Negative (and NaN) inputs map to 0; positive inputs map to x*x.
static inline float op_relu_sqr(float x) {
    return (x > 0.f) ? x * x : 0.f;
}
|
||||||
|
|
||||||
static inline float op_sigmoid(float x) {
|
static inline float op_sigmoid(float x) {
|
||||||
return 1.f / (1.f + expf(-x));
|
return 1.f / (1.f + expf(-x));
|
||||||
}
|
}
|
||||||
|
|
@ -262,6 +267,10 @@ void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor *
|
||||||
unary_op<op_relu>(params, dst);
|
unary_op<op_relu>(params, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CPU forward kernel for GGML_UNARY_OP_RELU_SQR: applies op_relu_sqr
// element-wise to dst's source tensor via the generic unary dispatcher.
void ggml_compute_forward_relu_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
    unary_op<op_relu_sqr>(params, dst);
}
|
||||||
|
|
||||||
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
|
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||||
unary_op<op_sigmoid>(params, dst);
|
unary_op<op_sigmoid>(params, dst);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ void ggml_compute_forward_step(const struct ggml_compute_params * params, struct
|
||||||
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
void ggml_compute_forward_relu_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
|
|
||||||
|
|
@ -2504,6 +2504,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
ggml_cuda_op_relu(ctx, dst);
|
ggml_cuda_op_relu(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_RELU_SQR:
|
||||||
|
ggml_cuda_op_relu_sqr(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
ggml_cuda_op_sigmoid(ctx, dst);
|
ggml_cuda_op_sigmoid(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
|
@ -4322,6 +4325,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
|
case GGML_UNARY_OP_RELU_SQR:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSIGMOID:
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
case GGML_UNARY_OP_HARDSWISH:
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,11 @@ static __device__ __forceinline__ float op_relu(float x) {
|
||||||
return fmaxf(x, 0);
|
return fmaxf(x, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Squared ReLU for CUDA: max(x, 0)^2. fmaxf clamps negatives (and NaN) to 0
// before squaring, matching the CPU op_relu_sqr implementation.
static __device__ __forceinline__ float op_relu_sqr(float x) {
    const float clamped = fmaxf(x, 0.0f);
    return clamped * clamped;
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_sigmoid(float x) {
|
static __device__ __forceinline__ float op_sigmoid(float x) {
|
||||||
return 1.0f / (1.0f + expf(-x));
|
return 1.0f / (1.0f + expf(-x));
|
||||||
}
|
}
|
||||||
|
|
@ -186,6 +191,10 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_cuda_op_unary<op_relu>(ctx, dst);
|
ggml_cuda_op_unary<op_relu>(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CUDA forward launcher for GGML_UNARY_OP_RELU_SQR: runs op_relu_sqr
// element-wise through the generic unary kernel wrapper.
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_relu_sqr>(ctx, dst);
}
|
||||||
|
|
||||||
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
|
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
|
||||||
|
|
@ -1170,6 +1170,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||||
"TANH",
|
"TANH",
|
||||||
"ELU",
|
"ELU",
|
||||||
"RELU",
|
"RELU",
|
||||||
|
"RELU_SQR",
|
||||||
"SIGMOID",
|
"SIGMOID",
|
||||||
"GELU",
|
"GELU",
|
||||||
"GELU_QUICK",
|
"GELU_QUICK",
|
||||||
|
|
@ -1187,7 +1188,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||||
"TRUNC",
|
"TRUNC",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_UNARY_OP_COUNT == 22, "GGML_UNARY_OP_COUNT != 22");
|
static_assert(GGML_UNARY_OP_COUNT == 23, "GGML_UNARY_OP_COUNT != 23");
|
||||||
|
|
||||||
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
||||||
"REGLU",
|
"REGLU",
|
||||||
|
|
@ -2645,6 +2646,20 @@ struct ggml_tensor * ggml_relu_inplace(
|
||||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
|
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_relu_sqr
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_relu_sqr(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_unary(ctx, a, GGML_UNARY_OP_RELU_SQR);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_relu_sqr_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU_SQR);
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_leaky_relu
|
// ggml_leaky_relu
|
||||||
|
|
||||||
struct ggml_tensor * ggml_leaky_relu(
|
struct ggml_tensor * ggml_leaky_relu(
|
||||||
|
|
|
||||||
|
|
@ -919,10 +919,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_RELU_SQR:
|
case LLM_FFN_RELU_SQR:
|
||||||
{
|
{
|
||||||
cur = ggml_relu(ctx0, cur);
|
cur = ggml_relu_sqr(ctx0, cur);
|
||||||
cb(cur, "ffn_relu", il);
|
|
||||||
|
|
||||||
cur = ggml_sqr(ctx0, cur);
|
|
||||||
cb(cur, "ffn_sqr(relu)", il);
|
cb(cur, "ffn_sqr(relu)", il);
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_SWIGLU:
|
case LLM_FFN_SWIGLU:
|
||||||
|
|
@ -1225,8 +1222,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
// TODO: add support for gated squared relu
|
// TODO: add support for gated squared relu
|
||||||
GGML_ABORT("fatal error: gated squared relu not implemented");
|
GGML_ABORT("fatal error: gated squared relu not implemented");
|
||||||
} else {
|
} else {
|
||||||
cur = ggml_relu(ctx0, cur);
|
cur = ggml_relu_sqr(ctx0, cur);
|
||||||
cur = ggml_sqr(ctx0, cur);
|
|
||||||
cb(cur, "ffn_moe_relu_sqr", il);
|
cb(cur, "ffn_moe_relu_sqr", il);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue