ggml webgpu: fix workgroup dispatch limit for large batch sizes (#19965)

* ggml-webgpu: fix workgroup dispatch limit for large batch sizes

WebGPU limits the number of workgroups dispatched per dimension to 65535. Large MUL_MAT
operations with batch sizes exceeding this limit would fail.

* add compute_2d_workgroups() helper to split total workgroup ID across
X/Y dimensions

* update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D
   dispatch

* update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID
  from 2D dispatch

* update mul_mat.wgsl to compute global index from 2D workgroup
  coordinates

* refactor all three mul_mat dispatch paths to use the shared helper

* ggml-webgpu: add bounds checking for over-dispatched workgroups

2D workgroup dispatch can over-dispatch when total workgroups don't
divide evenly into the 65535 per-dimension limit. Extra workgroups
would compute invalid batch indices, causing memory corruption.

* add batch_idx bound check to mul_mat_reg_tile.wgsl and
mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups
from accessing invalid memory

* fixes test failures with large batch sizes (e.g., bs=[128, 1024])

* ggml-webgpu: add back TODO for splitting large sizes into batches

* Optimize 2d workgroup provisioning

* Set some parameters that increase speed

---------

Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
Abhijit Ramesh 2026-03-02 19:35:11 -08:00 committed by GitHub
parent 4d828bd1ab
commit 49a7564ac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 49 additions and 20 deletions

View File

@ -31,6 +31,13 @@
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// Split `total_wg` workgroups into a rectangular (wg_x, wg_y) grid so that each
// dimension stays within `max_per_dim`, keeping the number of over-provisioned
// (unused) workgroups as small as possible for the chosen row count.
// Precondition: total_wg must not exceed max_per_dim^2.
static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
    // Fewest rows such that each row fits under the per-dimension cap;
    // clamp to at least one row so a zero total still yields a valid grid.
    uint32_t rows = CEIL_DIV(total_wg, max_per_dim);
    if (rows == 0) {
        rows = 1;
    }
    wg_y = rows;
    // Columns per row: wg_x * wg_y >= total_wg, with minimal excess for this wg_y.
    wg_x = CEIL_DIV(total_wg, wg_y);
}
#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 512
@ -69,8 +76,8 @@
/* Constants */
#define WEBGPU_NUM_PARAM_BUFS 16u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
#define WEBGPU_NUM_PARAM_BUFS 48u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
};
// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
uint32_t wg_x = 1;
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
}
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
wg_y = 1;
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);

View File

@ -679,19 +679,24 @@ struct MulMatParams {
@group(0) @binding(3) var<uniform> params: MulMatParams;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
let global_idx = wg_linear * 256u + local_id.x;
let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
if (global_id.x >= total) {
if (global_idx >= total) {
return;
}
let dst2_stride = params.m * params.n;
let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
let dst3_idx = global_id.x / dst3_stride;
let dst3_idx = global_idx / dst3_stride;
let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
let src13_idx = dst3_idx; // src1 is not broadcast
let dst3_rem = global_id.x % dst3_stride;
let dst3_rem = global_idx % dst3_stride;
let dst2_idx = dst3_rem / dst2_stride;
let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension

View File

@ -54,7 +54,8 @@ var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>) {
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let thread_id = local_id.x;
let local_m = get_local_m(thread_id);
@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
let wg_per_matrix = wg_m_count * wg_n_count;
let batch_idx = wg_id.x / wg_per_matrix;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
let wg_in_batch = wg_id.x % wg_per_matrix;
let batch_idx = wg_linear / wg_per_matrix;
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
if (batch_idx >= total_batches) {
return;
}
let wg_in_batch = wg_linear % wg_per_matrix;
let wg_m = wg_in_batch % wg_m_count;
let wg_n = wg_in_batch / wg_m_count;

View File

@ -69,7 +69,8 @@ var<workgroup> shmem: array<f16, SHMEM_SIZE>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32) {
@builtin(subgroup_id) subgroup_id: u32,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let thread_id = local_id.x;
let subgroup_m = subgroup_id % SUBGROUP_M;
@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
let wg_per_matrix = wg_m_count * wg_n_count;
let batch_idx = wg_id.x / wg_per_matrix;
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
let wg_in_batch = wg_id.x % wg_per_matrix;
let batch_idx = wg_linear / wg_per_matrix;
let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
if (batch_idx >= total_batches) {
return;
}
let wg_in_batch = wg_linear % wg_per_matrix;
let wg_m = wg_in_batch % wg_m_count;
let wg_n = wg_in_batch / wg_m_count;