Fix bug in dispatching large matrix-vector multiplication

This commit is contained in:
Reese Levine 2026-02-11 21:57:28 -08:00
parent 4d3daf80f8
commit 36f28fe8b9
1 changed files with 2 additions and 1 deletions

View File

@ -1168,7 +1168,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
} else {
pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];