vulkan: optimize im2col (#21713)

* vulkan: improve im2col memory write layout

* cap workgroups

* minimal device tuning

* use vendor_id instead of subgroup size
This commit is contained in:
Ruben Ortlam 2026-04-15 19:04:51 +02:00 committed by GitHub
parent 7e72b38bc1
commit b3d758750a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 63 deletions

View File

@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
uint32_t KW; uint32_t KH;
uint32_t pelements;
uint32_t OH_batch;
uint32_t CHW;
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
const uint32_t CHW = IC * KH * KW;
// Cap X workgroups to limit concurrent IC channel reads.
// The shader loops over X to cover the full CHW dimension.
// AMD devices use a lower minimum for this cap (512 instead of 4096).
const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
elements = { x_elements, OW, OH * batch };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
OH * batch,
IC * KH * KW,
s0, s1, p0, p1, d0, d1, batch * IC
});

View File

@ -13,7 +13,7 @@ layout (push_constant) uniform parameter
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint OH_batch;
uint CHW;
int s0; int s1;
int p0; int p1;
@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
void im2col(const uint ow, const uint z_idx) {
const uint oh = z_idx % p.OH;
const uint batch_idx = z_idx / p.OH;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint gidx = gl_LocalInvocationID.x;
const uint src_batch = batch_idx * p.batch_offset;
const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint KHKW = p.KH * p.KW;
const uint base_linear_idx = gidx * NUM_ITER;
uint wg_x = gl_WorkGroupID.x;
do {
const uint wg_offset = wg_x * 512;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
if (chw_idx >= p.CHW) {
return;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint ic = chw_idx / KHKW;
const uint rem = chw_idx - ic * KHKW;
const uint ky = rem / p.KW;
const uint kx = rem - ky * p.KW;
const uint linear_idx = base_linear_idx + idx;
const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
if (linear_idx >= p.pelements) {
continue;
}
A_TYPE val = A_TYPE(0);
if (iih < p.IH && iiw < p.IW) {
val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
out_ptr.d = D_TYPE(val);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
data_d[dst_row + chw_idx] = D_TYPE(val);
#endif
}
}
wg_x += gl_NumWorkGroups.x;
} while (wg_x * 512 < p.CHW);
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint ow = gl_GlobalInvocationID.y;
while (ow < p.OW) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
while (z < p.OH_batch) {
im2col(ow, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
ow += gl_NumWorkGroups.y;
}
}