vulkan: fix im2col overflowing maxworkgroupcount (#18180)
This commit is contained in:
parent
b365c3ff01
commit
fd05c51cec
|
|
@ -1261,6 +1261,7 @@ struct vk_op_im2col_push_constants {
|
|||
int32_t s0; int32_t s1;
|
||||
int32_t p0; int32_t p1;
|
||||
int32_t d0; int32_t d1;
|
||||
uint32_t batch_IC;
|
||||
};
|
||||
|
||||
struct vk_op_im2col_3d_push_constants {
|
||||
|
|
@ -5902,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
|
|||
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
|
||||
}
|
||||
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
|
||||
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
||||
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
||||
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
|
||||
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
|
||||
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
|
||||
|
|
@ -9090,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
|
||||
elements = { OW * KW * KH, OH, batch * IC };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
} break;
|
||||
case GGML_OP_IM2COL_3D:
|
||||
{
|
||||
|
|
@ -10605,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
const uint32_t pelements = OW * KW * KH;
|
||||
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
|
||||
|
|
@ -10617,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
IC, IW, IH, OW, OH, KW, KH,
|
||||
pelements,
|
||||
IC * KH * KW,
|
||||
s0, s1, p0, p1, d0, d1,
|
||||
s0, s1, p0, p1, d0, d1, batch * IC
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
|
|||
int s0; int s1;
|
||||
int p0; int p1;
|
||||
int d0; int d1;
|
||||
uint batch_IC;
|
||||
} p;
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
|
@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
void im2col(const uint y, const uint z) {
|
||||
const uint gidx = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint oh = gl_GlobalInvocationID.y;
|
||||
const uint batch = gl_GlobalInvocationID.z / p.IC;
|
||||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||
const uint oh = y;
|
||||
const uint batch = z / p.IC;
|
||||
const uint ic = z % p.IC;
|
||||
|
||||
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
||||
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
|
||||
|
|
@ -101,3 +102,15 @@ void main() {
|
|||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
uint y = gl_GlobalInvocationID.y;
|
||||
while (y < p.OH) {
|
||||
uint z = gl_GlobalInvocationID.z;
|
||||
while (z < p.batch_IC) {
|
||||
im2col(y, z);
|
||||
z += gl_NumWorkGroups.z;
|
||||
}
|
||||
y += gl_NumWorkGroups.y;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6930,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
|
||||
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
|
||||
|
||||
// im2col 3D
|
||||
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
|
||||
|
|
|
|||
Loading…
Reference in New Issue