diff --git a/docs/ops.md b/docs/ops.md index 47534b1401..14b53f2c40 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -62,7 +62,7 @@ Legend: | HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | -| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | | MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index 56bae2f3c8..4b735d4579 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -5744,49 +5744,61 @@ "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000000","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000001","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000100","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","0","no","WebGPU" -"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index ad665e4de9..9d16abf20d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { } }; +/** Row Norm **/ + +struct ggml_webgpu_row_norm_pipeline_key { + ggml_op op; + bool inplace; + + bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { + return op == other.op && inplace == other.inplace; + } +}; + +struct ggml_webgpu_row_norm_pipeline_key_hash { + size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib { std::unordered_map argsort_pipelines; // key is order std::unordered_map argsort_merge_pipelines; // key is order std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + row_norm_pipelines; // op/inplace std::unordered_map get_rows_pipelines; // src_type, vectorized std::unordered_map @@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib { return sum_rows_pipelines[1]; } + webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_row_norm_pipeline_key key = { + .op = context.dst->op, + .inplace = context.inplace, + }; + + auto it = row_norm_pipelines.find(key); + if (it != row_norm_pipelines.end()) { + return it->second; + } + std::vector defines; + std::string variant; + + switch (key.op) { + case GGML_OP_RMS_NORM: + defines.push_back("OP_RMS_NORM"); + variant = "rms_norm"; + break; + case GGML_OP_L2_NORM: + defines.push_back("OP_L2_NORM"); + variant = "l2_norm"; + break; + default: + GGML_ABORT("Unsupported op for row_norm shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + return row_norm_pipelines[key]; + } + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { bool vec4 = context.src0->ne[0] % 4 == 0; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4b0eeac0f4..f7973df682 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -366,7 +366,6 @@ struct webgpu_context_struct { std::map> cpy_pipelines; // src_type, dst_type - std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split @@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), @@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, - entries, ggml_nrows(src)); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -2192,7 +2198,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: - return ggml_webgpu_rms_norm(ctx, src0, node); + case GGML_OP_L2_NORM: + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: @@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } -static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); -} - static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); @@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl similarity index 79% rename from ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index 712b921f1a..7777944941 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -1,21 +1,11 @@ -#define(VARIANTS) - -[ - { - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_SUFFIX": "inplace", - "DECLS": ["INPLACE"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) +#ifdef INPLACE +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} +@group(0) @binding(1) +var params: Params; +#else fn update(src_offset: u32, dst_offset: u32, scale: f32) { dst[dst_offset] = scale * src[src_offset]; } @@ -25,23 +15,7 @@ var dst: array; @group(0) @binding(2) var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) +#endif struct Params { offset_src: u32, // in elements @@ -68,12 +42,9 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS +var scratch: array; -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { @@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; var sum = 0.0f; var col = lid.x; @@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3, break; } sum += pow(src[i_src_row + col], 2.0); - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - var offset = wg_size / 2; + var offset: u32 = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; @@ -110,14 +81,17 @@ fn main(@builtin(workgroup_id) wid: vec3, } sum = scratch[0]; +#ifdef OP_RMS_NORM let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); +#elif OP_L2_NORM + let scale = 1.0/max(sqrt(sum), params.eps); +#endif col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { break; } update(i_src_row + col, i_dst_row + col, scale); - col += wg_size; + col += WG_SIZE; } } -#end(SHADER)