From 66d7c143597d5dc8adce08d23c17444fb20d0c7c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 11 Feb 2026 00:55:02 -0600 Subject: [PATCH] vulkan: split mul_mat into multiple dispatches to avoid overflow The batch dimensions can be greater than the max workgroup count limit, in which case we need to split into multiple dispatches and pass the base index through a push constant. Fall back for the less common p021 and nc variants. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 78 ++++++++++++------- .../vulkan-shaders/mul_mat_vec_base.glsl | 5 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 8 +- .../vulkan-shaders/mul_mm_cm2.comp | 8 +- 4 files changed, 64 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72097ffd0f..a7728a95e4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -942,6 +942,7 @@ struct vk_mat_mat_push_constants { uint32_t M; uint32_t N; uint32_t K; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t base_work_group_z; uint32_t num_batches; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; @@ -961,6 +962,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t fusion_flags; + uint32_t base_work_group_y; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -6766,8 +6768,16 @@ static void ggml_vk_matmul( uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z }); + base_work_group_z += groups_z; + } return; } @@ -6781,9 +6791,16 @@ static void ggml_vk_matmul( uint32_t k_split = CEIL_DIV(k, split_k); k_split = ROUNDUP_POW2(k_split, 256); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; - // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z }); + } ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); @@ -7179,7 +7196,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } @@ -7477,7 +7493,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7572,22 +7587,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } - // compute - const vk_mat_vec_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, - fusion_flags, - (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - }, - pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + uint32_t base_work_group_y = 0; + while (base_work_group_y < ne12 * ne13) { + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + + uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + fusion_flags, base_work_group_y, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + }, + pc, { groups_x, groups_y, groups_z }); + base_work_group_y += groups_y; + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -7825,10 +7846,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c src1->nb[2] <= src1->nb[1] && src1->nb[1] <= src1->nb[3] && src0->ne[3] == 1 && - src1->ne[3] == 1) { + src1->ne[3] == 1 && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx); } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && - !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + !ggml_is_permuted(src0) && !ggml_is_permuted(src1) && + src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx); // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. @@ -11543,7 +11569,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -12052,7 +12077,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index 4f2c700306..4aeda68c7f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -32,6 +32,7 @@ layout (push_constant) uniform parameter uint expert_i1; uint nbi1; #else + uint base_work_group_y; uint ne02; uint ne12; uint broadcast2; @@ -45,9 +46,9 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_i0 = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_WorkGroupID.y; #else - const uint batch_idx = gl_GlobalInvocationID.y; + const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y; #endif #ifndef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 775e9a70f6..79344d3300 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -90,6 +90,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -139,7 +141,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -149,7 +151,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -366,7 +368,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif #ifdef COOPMAT diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index b6614d2fc5..717d124e01 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -53,6 +53,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -197,7 +199,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -215,7 +217,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -255,7 +257,7 @@ void main() { #else uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif uint stride_a = p.stride_a / QUANT_K;