Add concat op to webgpu. (#20068)

This commit is contained in:
Masashi Yoshimura 2026-03-05 04:19:00 +09:00 committed by GitHub
parent d969e933e1
commit 541bf37622
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 230 additions and 33 deletions

View File

@ -24,7 +24,7 @@ Legend:
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |

View File

@ -9535,38 +9535,38 @@
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","WebGPU"

Can't render this file because it is too large.

View File

@ -173,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash {
}
};
/** Concat **/
struct ggml_webgpu_concat_pipeline_key {
int type;
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
};
struct ggml_webgpu_concat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
return seed;
}
};
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
@ -403,6 +419,8 @@ class ggml_webgpu_shader_lib {
pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
concat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@ -1096,6 +1114,43 @@ class ggml_webgpu_shader_lib {
return binary_pipelines[key];
}
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_concat_pipeline_key key = {
.type = context.dst->type,
};
auto it = concat_pipelines.find(key);
if (it != concat_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "concat";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_I32:
defines.push_back("TYPE_I32");
variant += "_i32";
break;
default:
GGML_ABORT("Unsupported type for concat shader");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
return concat_pipelines[key];
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;

View File

@ -1484,6 +1484,68 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t)src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
{
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
},
{
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
},
{
.binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
}
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
@ -2068,6 +2130,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_MUL:
case GGML_OP_DIV:
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_CONCAT:
return ggml_webgpu_concat(ctx, src0, src1, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
@ -2894,6 +2958,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
(src1->type == op->type);
break;
case GGML_OP_CONCAT:
supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&

View File

@ -0,0 +1,75 @@
struct Params {
ne: u32,
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
dim: u32,
src0_nedim: u32
};
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_I32
#define DataType i32
#endif
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
var ni = array<u32, 4>(i0, i1, i2, i3);
if (ni[params.dim] < params.src0_nedim) {
let src_i = ni[0] * params.stride_src0_0 +
ni[1] * params.stride_src0_1 +
ni[2] * params.stride_src0_2 +
ni[3] * params.stride_src0_3;
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
} else {
ni[params.dim] -= params.src0_nedim;
let src_i = ni[0] * params.stride_src1_0 +
ni[1] * params.stride_src1_1 +
ni[2] * params.stride_src1_2 +
ni[3] * params.stride_src1_3;
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
}
}
}