vulkan: handle quantize_q8_1 overflowing the max workgroup count
This commit is contained in:
parent
cd78e57c3a
commit
9b1b96187e
|
|
@ -6721,7 +6721,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
||||||
|
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
|
const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
|
||||||
|
// Clamp the number of elements so the dispatch does not exceed the max workgroup count. The shader then loops until all num_blocks blocks are processed.
|
||||||
|
const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
|
||||||
|
const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
|
||||||
|
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
{
|
{
|
||||||
uint ne;
|
uint ne;
|
||||||
|
uint num_blocks;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
#include "types.glsl"
|
#include "types.glsl"
|
||||||
|
|
@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
|
||||||
shared float shmem[GROUP_SIZE];
|
shared float shmem[GROUP_SIZE];
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void quantize() {
|
void quantize(const uint wgid) {
|
||||||
const uint wgid = gl_WorkGroupID.x;
|
|
||||||
const uint tid = INVOCATION_ID;
|
const uint tid = INVOCATION_ID;
|
||||||
|
|
||||||
// Each thread handles a vec4, so 8 threads handle a block
|
// Each thread handles a vec4, so 8 threads handle a block
|
||||||
|
|
@ -45,11 +45,7 @@ void quantize() {
|
||||||
const uint ib = wgid * blocks_per_group + block_in_wg;
|
const uint ib = wgid * blocks_per_group + block_in_wg;
|
||||||
const uint iqs = tid % 8;
|
const uint iqs = tid % 8;
|
||||||
|
|
||||||
#ifndef QBLOCK_X4
|
#ifdef QBLOCK_X4
|
||||||
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
const uint ibx4_outer = ib / 4;
|
const uint ibx4_outer = ib / 4;
|
||||||
const uint ibx4_inner = ib % 4;
|
const uint ibx4_inner = ib % 4;
|
||||||
|
|
||||||
|
|
@ -123,5 +119,9 @@ void quantize() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
quantize();
|
uint wgid = gl_WorkGroupID.x;
|
||||||
|
while (wgid < p.num_blocks) {
|
||||||
|
quantize(wgid);
|
||||||
|
wgid += gl_NumWorkGroups.x;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7539,6 +7539,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 512, 32768, {1, 1}, {1, 1}));
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// test the mat-mat path for Metal
|
// test the mat-mat path for Metal
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue