From 49a7564ac1e845d7d7d61e9236242282dc5d8248 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Mon, 2 Mar 2026 19:35:11 -0800 Subject: [PATCH] ggml webgpu: fix workgroup dispatch limit for large batch sizes (#19965) * ggml-webgpu: fix workgroup dispatch limit for large batch sizes WebGPU limits workgroup sizes to 65535 per dimension. Large MUL_MAT operations with batch sizes exceedeing this limi would fail. * add compute_2d_workgroups() helper to split total workgroup ID across X/Y dimensions * update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat.wgsl to compute global index from 2D workgroup coordinates * refactor all three mul_mat dispatch paths to use the shared helper * ggml-webgpu: add bounds checking for over-dispatched workgroups 2D workgroup dispatch can over-dispatch when total workgroups don't divide evenly into the 65535 per-dimension limit. Extra workgroups would compute invalid batch indices, causing memory corruption. * add batch_idx bound check to mul_mat_reg_tile.wgsl and mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups from accessing invalid memory * fixes test failures with large batch sizes (eg., bs=[128, 1024]) * ggml-webgpu: add back TODO for spliting large sizes into batches * Optimize 2d workgroup provisioning * Set some parameters that increase speed --------- Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 28 ++++++++++++------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 13 ++++++--- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 14 ++++++++-- .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 14 ++++++++-- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 913cf7f882..19451618ec 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -31,6 +31,13 @@ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +// Return a rectangular grid of workgroups with minimal over-provisioned workgroups. +// Assumes that the total number of workgroups does not exceed max_per_dim^2. +static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) { + wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim)); + wg_x = CEIL_DIV(total_wg, wg_y); +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -69,8 +76,8 @@ /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 16u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_NUM_PARAM_BUFS 48u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool @@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { auto decisions = static_cast(pipeline.context.get()); @@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; - // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups - wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { auto decisions = static_cast(pipeline.context.get()); @@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_m = CEIL_DIV(dst->ne[0], tile_m_s); wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } else { // legacy auto decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; - wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - wg_y = 1; + uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 6aba47317c..5b9f5b3622 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -679,19 +679,24 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; @compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3) { +fn main(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let global_idx = wg_linear * 256u + local_id.x; + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { + if (global_idx >= total) { return; } let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = global_id.x / dst3_stride; + let dst3_idx = global_idx / dst3_stride; let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; + let dst3_rem = global_idx % dst3_stride; let dst2_idx = dst3_rem / dst2_stride; let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 771e5cd1ee..761e3017c1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -54,7 +54,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let local_m = get_local_m(thread_id); @@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 64529e03cd..9f9ef279f2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -69,7 +69,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let subgroup_m = subgroup_id % SUBGROUP_M; @@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count;