vulkan: handle quantize_q8_1 overflowing the max workgroup count

This commit is contained in:
Jeff Bolz 2025-12-31 17:45:34 -06:00
parent cd78e57c3a
commit 9b1b96187e
3 changed files with 15 additions and 9 deletions

View File

@ -6721,7 +6721,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
// clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
ggml_vk_sync_buffers(ctx, subctx);
}

View File

@ -15,6 +15,7 @@
layout (push_constant) uniform parameter
{
uint ne;
uint num_blocks;
} p;
#include "types.glsl"
@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
shared float shmem[GROUP_SIZE];
#endif
void quantize() {
const uint wgid = gl_WorkGroupID.x;
void quantize(const uint wgid) {
const uint tid = INVOCATION_ID;
// Each thread handles a vec4, so 8 threads handle a block
@ -45,11 +45,7 @@ void quantize() {
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;
#ifndef QBLOCK_X4
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
#else
#ifdef QBLOCK_X4
const uint ibx4_outer = ib / 4;
const uint ibx4_inner = ib % 4;
@ -123,5 +119,9 @@ void quantize() {
}
void main() {
quantize();
uint wgid = gl_WorkGroupID.x;
while (wgid < p.num_blocks) {
quantize(wgid);
wgid += gl_NumWorkGroups.x;
}
}

View File

@ -7539,6 +7539,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 512, 32768, {1, 1}, {1, 1}));
#if 0
// test the mat-mat path for Metal