From 9b1b96187e5ee2b5338798e00b5ffd9d05016dd0 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 31 Dec 2025 17:45:34 -0600 Subject: [PATCH] vulkan: handle quantize_q8_1 overflowing the max workgroup count --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 ++++++- .../vulkan-shaders/quantize_q8_1.comp | 16 ++++++++-------- tests/test-backend-ops.cpp | 1 + 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 493ee9c9a4..d7f1d865e0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6721,7 +6721,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]); + // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks. + const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits::max()); + const uint32_t elements = std::min(ne, static_cast(max_elements)); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ ne, num_blocks }, { elements, 1, 1 }); ggml_vk_sync_buffers(ctx, subctx); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index 20e45d0253..7ea29a07e3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -15,6 +15,7 @@ layout (push_constant) uniform parameter { uint ne; + uint num_blocks; } p; #include "types.glsl" @@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; shared float shmem[GROUP_SIZE]; #endif -void quantize() { - const uint wgid = gl_WorkGroupID.x; +void quantize(const uint wgid) { const uint tid = INVOCATION_ID; // Each thread handles a vec4, so 8 threads handle a block @@ -45,11 +45,7 @@ void quantize() { const uint ib = wgid * blocks_per_group + block_in_wg; const uint iqs = tid % 8; -#ifndef QBLOCK_X4 - if (ib >= gl_NumWorkGroups.x * blocks_per_group) { - return; - } -#else +#ifdef QBLOCK_X4 const uint ibx4_outer = ib / 4; const uint ibx4_inner = ib % 4; @@ -123,5 +119,9 @@ void quantize() { } void main() { - quantize(); + uint wgid = gl_WorkGroupID.x; + while (wgid < p.num_blocks) { + quantize(wgid); + wgid += gl_NumWorkGroups.x; + } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0b981b1788..57357a635c 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7539,6 +7539,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 512, 32768, {1, 1}, {1, 1})); #if 0 // test the mat-mat path for Metal