vulkan: optimize im2col (#21713)
* vulkan: improve im2col memory write layout
* cap workgroups
* minimal device tuning
* use vendor_id instead of subgroup size
This commit is contained in:
parent
7e72b38bc1
commit
b3d758750a
|
|
@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants {
|
|||
uint32_t IW; uint32_t IH;
|
||||
uint32_t OW; uint32_t OH;
|
||||
uint32_t KW; uint32_t KH;
|
||||
uint32_t pelements;
|
||||
uint32_t OH_batch;
|
||||
uint32_t CHW;
|
||||
int32_t s0; int32_t s1;
|
||||
int32_t p0; int32_t p1;
|
||||
|
|
@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
|
||||
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
|
||||
elements = { OW * KW * KH, OH, batch * IC };
|
||||
const uint32_t CHW = IC * KH * KW;
|
||||
// Cap X workgroups to limit concurrent IC channel reads.
|
||||
// The shader loops over X to cover the full CHW dimension.
|
||||
// AMD prefers a lower limit
|
||||
const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
|
||||
const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
|
||||
elements = { x_elements, OW, OH * batch };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
} break;
|
||||
|
|
@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
const uint32_t pelements = OW * KW * KH;
|
||||
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
|
||||
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
|
|
@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
dst_addr,
|
||||
batch_offset, offset_delta,
|
||||
IC, IW, IH, OW, OH, KW, KH,
|
||||
pelements,
|
||||
OH * batch,
|
||||
IC * KH * KW,
|
||||
s0, s1, p0, p1, d0, d1, batch * IC
|
||||
});
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ layout (push_constant) uniform parameter
|
|||
uint IW; uint IH;
|
||||
uint OW; uint OH;
|
||||
uint KW; uint KH;
|
||||
uint pelements;
|
||||
uint OH_batch;
|
||||
uint CHW;
|
||||
int s0; int s1;
|
||||
int p0; int p1;
|
||||
|
|
@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|||
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
|
||||
#endif
|
||||
|
||||
void im2col(const uint y, const uint z) {
|
||||
const uint gidx = gl_GlobalInvocationID.x;
|
||||
void im2col(const uint ow, const uint z_idx) {
|
||||
const uint oh = z_idx % p.OH;
|
||||
const uint batch_idx = z_idx / p.OH;
|
||||
|
||||
const uint oh = y;
|
||||
const uint batch = z / p.IC;
|
||||
const uint ic = z % p.IC;
|
||||
const uint gidx = gl_LocalInvocationID.x;
|
||||
const uint src_batch = batch_idx * p.batch_offset;
|
||||
const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;
|
||||
|
||||
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
||||
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
|
||||
const int oh_s1 = int(oh) * p.s1;
|
||||
const uint ksize = p.OW * p.KH;
|
||||
const uint KHKW = p.KH * p.KW;
|
||||
|
||||
const uint base_linear_idx = gidx * NUM_ITER;
|
||||
uint wg_x = gl_WorkGroupID.x;
|
||||
do {
|
||||
const uint wg_offset = wg_x * 512;
|
||||
|
||||
uint current_kx = base_linear_idx / ksize;
|
||||
const uint rem = base_linear_idx - (current_kx * ksize);
|
||||
uint current_ky = rem / p.OW;
|
||||
uint current_ix = rem % p.OW;
|
||||
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
|
||||
const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
|
||||
|
||||
A_TYPE values[NUM_ITER];
|
||||
BDA_OFFSET_T offset_dst[NUM_ITER];
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
values[idx] = A_TYPE(0);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
|
||||
const uint linear_idx = base_linear_idx + idx;
|
||||
|
||||
if (linear_idx >= p.pelements) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
|
||||
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
|
||||
|
||||
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
|
||||
|
||||
if ((iih < p.IH) && (iiw < p.IW)) {
|
||||
values[idx] = data_a[src_base + iih * p.IW + iiw];
|
||||
}
|
||||
|
||||
if (++current_ix == p.OW) {
|
||||
current_ix = 0;
|
||||
if (++current_ky == p.KH) {
|
||||
current_ky = 0;
|
||||
current_kx++;
|
||||
if (chw_idx >= p.CHW) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||
const uint ic = chw_idx / KHKW;
|
||||
const uint rem = chw_idx - ic * KHKW;
|
||||
const uint ky = rem / p.KW;
|
||||
const uint kx = rem - ky * p.KW;
|
||||
|
||||
const uint linear_idx = base_linear_idx + idx;
|
||||
const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
|
||||
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
|
||||
|
||||
if (linear_idx >= p.pelements) {
|
||||
continue;
|
||||
}
|
||||
A_TYPE val = A_TYPE(0);
|
||||
if (iih < p.IH && iiw < p.IW) {
|
||||
val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
|
||||
}
|
||||
|
||||
#if BDA
|
||||
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
|
||||
dst_addr.d = D_TYPE(values[idx]);
|
||||
D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
|
||||
out_ptr.d = D_TYPE(val);
|
||||
#else
|
||||
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
|
||||
data_d[dst_row + chw_idx] = D_TYPE(val);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
wg_x += gl_NumWorkGroups.x;
|
||||
} while (wg_x * 512 < p.CHW);
|
||||
}
|
||||
|
||||
void main() {
|
||||
uint y = gl_GlobalInvocationID.y;
|
||||
while (y < p.OH) {
|
||||
uint ow = gl_GlobalInvocationID.y;
|
||||
while (ow < p.OW) {
|
||||
uint z = gl_GlobalInvocationID.z;
|
||||
while (z < p.batch_IC) {
|
||||
im2col(y, z);
|
||||
while (z < p.OH_batch) {
|
||||
im2col(ow, z);
|
||||
z += gl_NumWorkGroups.z;
|
||||
}
|
||||
y += gl_NumWorkGroups.y;
|
||||
ow += gl_NumWorkGroups.y;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue