vulkan: fix im2col overflowing maxworkgroupcount (#18180)

This commit is contained in:
Jeff Bolz 2025-12-21 03:32:58 -06:00 committed by GitHub
parent b365c3ff01
commit fd05c51cec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 5 deletions

View File

@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
int32_t s0; int32_t s1; int32_t s0; int32_t s1;
int32_t p0; int32_t p1; int32_t p0; int32_t p1;
int32_t d0; int32_t d1; int32_t d0; int32_t d1;
uint32_t batch_IC;
}; };
struct vk_op_im2col_3d_push_constants { struct vk_op_im2col_3d_push_constants {
@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
} }
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC }; elements = { OW * KW * KH, OH, batch * IC };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break; } break;
case GGML_OP_IM2COL_3D: case GGML_OP_IM2COL_3D:
{ {
@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH; const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer; const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
IC, IW, IH, OW, OH, KW, KH, IC, IW, IH, OW, OH, KW, KH,
pelements, pelements,
IC * KH * KW, IC * KH * KW,
s0, s1, p0, p1, d0, d1, s0, s1, p0, p1, d0, d1, batch * IC
}); });
} }

View File

@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
int s0; int s1; int s0; int s1;
int p0; int p1; int p0; int p1;
int d0; int d1; int d0; int d1;
uint batch_IC;
} p; } p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;}; layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif #endif
void main() { void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x; const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y; const uint oh = y;
const uint batch = gl_GlobalInvocationID.z / p.IC; const uint batch = z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC; const uint ic = z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@ -101,3 +102,15 @@ void main() {
#endif #endif
} }
} }
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
}
}

View File

@ -6930,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
// im2col 3D // im2col 3D
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));