ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (#20665)
* Update the preprocessor of RMS_NORM and add L2_NORM. * Fix the name of rms_norm to row_norm.
This commit is contained in:
parent
ea01d196d7
commit
509a31d00f
|
|
@ -62,7 +62,7 @@ Legend:
|
|||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
|
|
@ -5744,49 +5744,61 @@
|
|||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000001","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000001,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.000100","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.000100,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[1025,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM_BACK","type=f32,ne=[1025,5,4,3],eps=0.100000","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=0.100000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=0","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","L2_NORM","type=f32,ne=[1025,5,4,3],eps=10.000000,v=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","WebGPU"
|
||||
"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"
|
||||
"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU"
|
||||
|
|
|
|||
|
Can't render this file because it is too large.
|
|
|
@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** Row Norm **/
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key {
|
||||
ggml_op op;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
|
||||
return op == other.op && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_row_norm_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Pad **/
|
||||
struct ggml_webgpu_pad_pipeline_key {
|
||||
bool circular;
|
||||
|
|
@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib {
|
|||
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
|
||||
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
|
||||
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
||||
std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash>
|
||||
row_norm_pipelines; // op/inplace
|
||||
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
|
||||
get_rows_pipelines; // src_type, vectorized
|
||||
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
|
||||
|
|
@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib {
|
|||
return sum_rows_pipelines[1];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_row_norm_pipeline_key key = {
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
};
|
||||
|
||||
auto it = row_norm_pipelines.find(key);
|
||||
if (it != row_norm_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::vector<std::string> defines;
|
||||
std::string variant;
|
||||
|
||||
switch (key.op) {
|
||||
case GGML_OP_RMS_NORM:
|
||||
defines.push_back("OP_RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("OP_L2_NORM");
|
||||
variant = "l2_norm";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported op for row_norm shader");
|
||||
}
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
return row_norm_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
bool vec4 = context.src0->ne[0] % 4 == 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -366,7 +366,6 @@ struct webgpu_context_struct {
|
|||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
|
||||
|
||||
|
|
@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
int inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool inplace = ggml_webgpu_tensor_equal(src, dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
|
|
@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
|
||||
entries, ggml_nrows(src));
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.dst = dst,
|
||||
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src));
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
|
||||
|
|
@ -2192,7 +2198,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_REPEAT:
|
||||
return ggml_webgpu_repeat(ctx, src0, node);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_webgpu_row_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_webgpu_rope(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_GLU:
|
||||
|
|
@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
|||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
|
||||
|
||||
webgpu_ctx->rms_norm_pipelines[0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
|
||||
webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
|
||||
webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
|
|
@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
|
||||
|
|
@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
|
|
|
|||
|
|
@ -1,21 +1,11 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "inplace",
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
#ifdef INPLACE
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
dst[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
|
@ -25,23 +15,7 @@ var<storage, read_write> dst: array<f32>;
|
|||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#decl(INPLACE)
|
||||
|
||||
fn update(src_offset: u32, dst_offset: u32, scale: f32) {
|
||||
src[dst_offset] = scale * src[src_offset];
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
|
|
@ -68,12 +42,9 @@ struct Params {
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
DECLS
|
||||
var<workgroup> scratch: array<f32, WG_SIZE>;
|
||||
|
||||
override wg_size: u32;
|
||||
var<workgroup> scratch: array<f32, wg_size>;
|
||||
|
||||
@compute @workgroup_size(wg_size)
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
||||
@builtin(local_invocation_id) lid: vec3<u32>) {
|
||||
|
||||
|
|
@ -86,7 +57,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
|
||||
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
|
||||
|
||||
let elems = (params.ne0 + wg_size - 1) / wg_size;
|
||||
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
|
||||
|
||||
var sum = 0.0f;
|
||||
var col = lid.x;
|
||||
|
|
@ -95,12 +66,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
break;
|
||||
}
|
||||
sum += pow(src[i_src_row + col], 2.0);
|
||||
col += wg_size;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
|
||||
scratch[lid.x] = sum;
|
||||
workgroupBarrier();
|
||||
var offset = wg_size / 2;
|
||||
var offset: u32 = WG_SIZE / 2;
|
||||
while (offset > 0) {
|
||||
if (lid.x < offset) {
|
||||
scratch[lid.x] += scratch[lid.x + offset];
|
||||
|
|
@ -110,14 +81,17 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef OP_RMS_NORM
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif OP_L2_NORM
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
break;
|
||||
}
|
||||
update(i_src_row + col, i_dst_row + col, scale);
|
||||
col += wg_size;
|
||||
col += WG_SIZE;
|
||||
}
|
||||
}
|
||||
#end(SHADER)
|
||||
Loading…
Reference in New Issue