ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (#20665)

* Update the preprocessor of RMS_NORM and add L2_NORM.

* Fix the name of rms_norm to row_norm.
This commit is contained in:
Masashi Yoshimura 2026-03-19 13:08:59 +09:00 committed by GitHub
parent ea01d196d7
commit 509a31d00f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 113 additions and 69 deletions

View File

@ -62,7 +62,7 @@ Legend:
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |

View File

@ -5744,49 +5744,61 @@
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000000","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000001","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000100","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"

Can't render this file because it is too large.

View File

@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
}
};
/** Row Norm **/
struct ggml_webgpu_row_norm_pipeline_key {
ggml_op op;
bool inplace;
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
return op == other.op && inplace == other.inplace;
}
};
struct ggml_webgpu_row_norm_pipeline_key_hash {
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib {
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
row_norm_pipelines; // op/inplace
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib {
return sum_rows_pipelines[1];
}
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_row_norm_pipeline_key key = {
.op = context.dst->op,
.inplace = context.inplace,
};
auto it = row_norm_pipelines.find(key);
if (it != row_norm_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant;
switch (key.op) {
case GGML_OP_RMS_NORM:
defines.push_back("OP_RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_L2_NORM:
defines.push_back("OP_L2_NORM");
variant = "l2_norm";
break;
default:
GGML_ABORT("Unsupported op for row_norm shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
return row_norm_pipelines[key];
}
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0;

View File

@ -366,7 +366,6 @@ struct webgpu_context_struct {
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool inplace = ggml_webgpu_tensor_equal(src, dst);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
entries, ggml_nrows(src));
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
}
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@ -2192,7 +2198,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_REPEAT:
return ggml_webgpu_repeat(ctx, src0, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_L2_NORM:
return ggml_webgpu_row_norm(ctx, src0, node);
case GGML_OP_ROPE:
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
case GGML_OP_GLU:
@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
webgpu_ctx->rms_norm_pipelines[0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
case GGML_OP_ROPE:

View File

@ -1,21 +1,11 @@
#define(VARIANTS)
[
{
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_SUFFIX": "inplace",
"DECLS": ["INPLACE"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
#ifdef INPLACE
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#else
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
dst[dst_offset] = scale * src[src_offset];
}
@ -25,23 +15,7 @@ var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
src[dst_offset] = scale * src[src_offset];
}
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
#endif
struct Params {
offset_src: u32, // in elements
@ -68,12 +42,9 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
var sum = 0.0f;
var col = lid.x;
@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset = wg_size / 2;
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
@ -110,14 +81,17 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
}
sum = scratch[0];
#ifdef OP_RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
#elif OP_L2_NORM
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += wg_size;
col += WG_SIZE;
}
}
#end(SHADER)