diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 913cf7f882..19451618ec 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -31,6 +31,13 @@ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +// Return a rectangular grid of workgroups with minimal over-provisioned workgroups. +// Assumes that the total number of workgroups does not exceed max_per_dim^2. +static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) { + wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim)); + wg_x = CEIL_DIV(total_wg, wg_y); +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -69,8 +76,8 @@ /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 16u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_NUM_PARAM_BUFS 48u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool @@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { auto decisions = static_cast(pipeline.context.get()); @@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; - // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups - wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { auto decisions = static_cast(pipeline.context.get()); @@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_m = CEIL_DIV(dst->ne[0], tile_m_s); wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } else { // legacy auto decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; - wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - wg_y = 1; + uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 6aba47317c..5b9f5b3622 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -679,19 +679,24 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; @compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3) { +fn main(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let global_idx = wg_linear * 256u + local_id.x; + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { + if (global_idx >= total) { return; } let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = global_id.x / dst3_stride; + let dst3_idx = global_idx / dst3_stride; let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; + let dst3_rem = global_idx % dst3_stride; let dst2_idx = dst3_rem / dst2_stride; let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 771e5cd1ee..761e3017c1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -54,7 +54,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let local_m = get_local_m(thread_id); @@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 64529e03cd..9f9ef279f2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -69,7 +69,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let subgroup_m = subgroup_id % SUBGROUP_M; @@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count;