Move rms_norm to split row approach

This commit is contained in:
Reese Levine 2025-09-24 15:48:26 -07:00
parent 483615da76
commit 27b893a6f8
3 changed files with 87 additions and 52 deletions

View File

@ -28,6 +28,7 @@
/* Constants */
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
@ -35,6 +36,9 @@
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
// For operations which process a row in parallel, this seems like a reasonable default
#define WEBGPU_ROW_SPLIT_WG_SIZE 64
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@ -257,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
// WebGPU implementations may limit the number of futures that can be waited on at once,
// so wait in batches (64 is what Dawn supports).
for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
}
ctx->callback_futures.clear();
}
}
@ -727,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
ggml_op_name(dst->op));
}
@ -1311,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
// The max workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
// Workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->max_wg_size_x;
constants[0].value = wg_size;
return constants;
}
@ -1383,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
"get_rows_f32_vec", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@ -1437,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@ -1449,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@ -1461,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@ -1473,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@ -1485,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@ -1497,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@ -1505,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
"rope_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@ -1525,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
// reglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
wgsl_reglu_f32, "reglu_f32", constants);
@ -1579,7 +1585,7 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
@ -1587,9 +1593,7 @@ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = 64;
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
"soft_max_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,

View File

@ -71,14 +71,14 @@ var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
return;
}
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
// one thread per row
var i = gid.x;
var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
var sum = 0.0f;
for (var j: u32 = 0; j < params.ne0; j++) {
sum += src[i_src_row + j] * src[i_src_row + j];
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
sum += pow(src[i_src_row + col], 2.0);
col += wg_size;
}
scratch[lid.x] = sum;
workgroupBarrier();
var offset = wg_size / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
}
offset = offset / 2;
workgroupBarrier();
}
sum = scratch[0];
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
for (var j: u32 = 0; j < params.ne0; j++) {
update(i_src_row + j, i_dst_row + j, scale);
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
break;
}
update(i_src_row + col, i_dst_row + col, scale);
col += wg_size;
}
}
#end(SHADER)

View File

@ -276,15 +276,17 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
var cache: array<f32, CACHE_SIZE>;
var max_val = lower_max_bound(i2);
var col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
let col = j * wg_size + lid.x;
if (col < params.ne0) {
let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
max_val = max(max_val, val);
if (col < CACHE_SIZE) {
cache[col] = val;
}
if (col >= params.ne0) {
break;
}
let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
max_val = max(max_val, val);
if (col < CACHE_SIZE) {
cache[col] = val;
}
col += wg_size;
}
scratch[lid.x] = max_val;
@ -300,19 +302,21 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let row_max = scratch[0];
var sum = 0.0f;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
let col = j * wg_size + lid.x;
if (col < params.ne0) {
let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
cache[col], col < CACHE_SIZE);
let ex = exp(val - row_max);
sum += ex;
if (col < CACHE_SIZE) {
cache[col] = ex;
} else {
update(i_dst_row + col, ex);
}
if (col >= params.ne0) {
break;
}
let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
cache[col], col < CACHE_SIZE);
let ex = exp(val - row_max);
sum += ex;
if (col < CACHE_SIZE) {
cache[col] = ex;
} else {
update(i_dst_row + col, ex);
}
col += wg_size;
}
scratch[lid.x] = sum;
@ -328,11 +332,13 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let row_sum = add_sinks(scratch[0], i2, row_max);
let sum_recip = 1.0 / row_sum;
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
let col = j * wg_size + lid.x;
if (col < params.ne0) {
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
if (col >= params.ne0) {
break;
}
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
col += wg_size;
}
}
#end(SHADER)