diff --git a/docs/ops.md b/docs/ops.md index 37329d56a8..f914c2b7d2 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -80,7 +80,7 @@ Legend: | POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | -| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ | +| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index 9e081e7605..b7761b9dd3 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -5023,20 +5023,20 @@ "WebGPU: WebGPU","ARGMAX","type=f32,ne=[1024,12,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGMAX","type=f32,ne=[2000,10,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGMAX","type=f32,ne=[5438,3,1,1]","support","1","yes","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,2,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,2,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,1],nr=[2,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,1],nr=[1,1,1,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,2,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,2,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=f32,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=i32,ne=[10,5,4,3],nr=[2,1,1,1]","support","1","yes","WebGPU" +"WebGPU: WebGPU","REPEAT","type=i16,ne=[10,5,4,3],nr=[1,1,1,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,1,1,1],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[2,1,1,1],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","REPEAT_BACK","type=f32,ne=[8,6,4,2],nr=[1,2,1,1],v=0","support","0","no","WebGPU" diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3c38b1a230..3d7e59fddf 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash { } }; +/** Repeat **/ + +struct ggml_webgpu_repeat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_repeat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib { binary_pipelines; // type/op/inplace/overlap std::unordered_map concat_pipelines; // type + std::unordered_map + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_map defines; - std::string variant = "concat"; + std::string variant = "concat"; switch (key.type) { case GGML_TYPE_F32: @@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_concat, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - concat_pipelines[key] = pipeline; + pipeline.context = decisions; + concat_pipelines[key] = pipeline; return concat_pipelines[key]; } + webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_repeat_pipeline_key key = { + .type = context.dst->type, + }; + + auto it = repeat_pipelines.find(key); + if (it != repeat_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "repeat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + case GGML_TYPE_I16: + defines.push_back("TYPE_I16"); + variant += "_i16"; + break; + default: + GGML_ABORT("Unsupported type for repeat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_repeat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + repeat_pipelines[key] = pipeline; + return repeat_pipelines[key]; + } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ccc34cb153..128b7dc3de 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / + ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src0->ne[0]), + (uint32_t) (src0->ne[1]), + (uint32_t) (src0->ne[2]), + (uint32_t) (src0->ne[3]), + (uint32_t) (dst->ne[0]), + (uint32_t) (dst->ne[1]), + (uint32_t) (dst->ne[2]) }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2158,6 +2200,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_CONCAT: return ggml_webgpu_concat(ctx, src0, src1, node); + case GGML_OP_REPEAT: + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, /* .alloc_buffer = */ - ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ - ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ - ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ - ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, @@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_CONCAT: supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); break; + case GGML_OP_REPEAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16); + break; case GGML_OP_CPY: case GGML_OP_CONT: supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl new file mode 100644 index 0000000000..6e2a1a8b61 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl @@ -0,0 +1,67 @@ +enable f16; + +struct Params { + ne: u32, + + offset_src0: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + a_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif +#ifdef TYPE_I16 +// same size (16-bit) is sufficient for repeat +#define DataType f16 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let a_i0 = i0 % params.a_ne0; + let a_i1 = i1 % params.a_ne1; + let a_i2 = i2 % params.a_ne2; + let a_i3 = i3 % params.a_ne3; + + let a_index = a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; + + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index]; + } +}