From 11c325c6e0666a30590cde390d5746a405e536b9 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 20 Feb 2026 01:18:30 +0900 Subject: [PATCH] ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (#19700) * ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. * Fix to cast the src value to f32 before sin/cos computing. --- docs/ops.md | 8 ++-- docs/ops/WebGPU.csv | 40 ++++++++++++-------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 ++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 14 +++++++ 4 files changed, 62 insertions(+), 20 deletions(-) diff --git a/docs/ops.md b/docs/ops.md index 5754b0a96c..296c0ba1d4 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -31,7 +31,7 @@ Legend: | CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | -| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | +| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -96,13 +96,13 @@ Legend: | SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | -| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | +| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | -| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ | -| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ | +| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | +| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | | STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index 72ebea2cd7..e2ed3e2cfa 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -8760,22 +8760,14 @@ "WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=1","support","0","no","WebGPU" "WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=32","support","0","no","WebGPU" "WebGPU: WebGPU","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","0","no","WebGPU" -"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","WebGPU" "WebGPU: WebGPU","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" @@ -8786,22 +8778,14 @@ "WebGPU: WebGPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","WebGPU" "WebGPU: WebGPU","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" -"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" @@ -18901,3 +18885,27 @@ "WebGPU: WebGPU","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","SQR","type=f16,ne=[10,5,4,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f16,ne=[10,3,3,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQR","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQR","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SQRT","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","SIN","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","COS","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU" diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b5fee48056..1c00d3cb2b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2008,6 +2008,14 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQR: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQRT: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SIN: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_COS: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_LOG: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_SQR: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SQRT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SIN: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_COS: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index d639d98497..feaf6d0ac2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { #ifdef TRUNC let res = trunc(src[params.offset_src + src_idx]); #endif +#ifdef SQR + let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; +#endif +#ifdef SQRT + let res = sqrt(src[params.offset_src + src_idx]); +#endif +#ifdef SIN + let res_f32 = sin(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif +#ifdef COS + let res_f32 = cos(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res;