From 541bf376229d6c5cb61e4eb59b3c7629755b8271 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 5 Mar 2026 04:19:00 +0900 Subject: [PATCH] Add concat op to webgpu. (#20068) --- docs/ops.md | 2 +- docs/ops/WebGPU.csv | 64 ++++++++-------- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 55 ++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 67 +++++++++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl | 75 +++++++++++++++++++ 5 files changed, 230 insertions(+), 33 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl diff --git a/docs/ops.md b/docs/ops.md index 296c0ba1d4..8213bc6abf 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -24,7 +24,7 @@ Legend: | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | -| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | +| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | | CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index e2ed3e2cfa..9e081e7605 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -9535,38 +9535,38 @@ "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","WebGPU" diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 369475eaf5..17c5e0fb51 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -173,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Concat **/ + +struct ggml_webgpu_concat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_concat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -403,6 +419,8 @@ class ggml_webgpu_shader_lib { pad_pipelines; // circular/non-circular std::unordered_map binary_pipelines; // type/op/inplace/overlap + std::unordered_map + concat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_maptype, + }; + + auto it = concat_pipelines.find(key); + if (it != concat_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "concat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for concat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + concat_pipelines[key] = pipeline; + return concat_pipelines[key]; + } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 19451618ec..334919e589 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1484,6 +1484,68 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + uint32_t dim = (uint32_t) dst->op_params[0]; + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t)src0->ne[dim] + }; + + std::vector entries = { + { + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) + }, + { + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) + }, + { + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) + } + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2068,6 +2130,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL: case GGML_OP_DIV: return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_CONCAT: + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2894,6 +2958,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1->type == op->type); break; + case GGML_OP_CONCAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); + break; case GGML_OP_CPY: case GGML_OP_CONT: supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl new file mode 100644 index 0000000000..a22d245d2c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -0,0 +1,75 @@ +struct Params { + ne: u32, + + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + dim: u32, + src0_nedim: u32 +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + var ni = array(i0, i1, i2, i3); + + if (ni[params.dim] < params.src0_nedim) { + let src_i = ni[0] * params.stride_src0_0 + + ni[1] * params.stride_src0_1 + + ni[2] * params.stride_src0_2 + + ni[3] * params.stride_src0_3; + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; + } else { + ni[params.dim] -= params.src0_nedim; + let src_i = ni[0] * params.stride_src1_0 + + ni[1] * params.stride_src1_1 + + ni[2] * params.stride_src1_2 + + ni[3] * params.stride_src1_3; + dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; + } + } +}