vulkan: Reduce temporary memory usage for TOP_K (#17623)
- Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB.
This commit is contained in:
parent
e251e5ebbe
commit
61bde8e21f
|
|
@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
|
|||
uint32_t orig_ncols;
|
||||
uint32_t ncols_input;
|
||||
uint32_t ncols_output;
|
||||
uint32_t k;
|
||||
uint32_t nrows;
|
||||
uint32_t first_pass;
|
||||
uint32_t last_pass;
|
||||
|
|
@ -1673,6 +1674,14 @@ class vk_perf_logger {
|
|||
timings[name.str()].push_back(time);
|
||||
return;
|
||||
}
|
||||
if (node->op == GGML_OP_TOP_K) {
|
||||
std::stringstream name;
|
||||
name << ggml_op_name(node->op) <<
|
||||
" K=" << node->ne[0] <<
|
||||
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
|
||||
timings[name.str()].push_back(time);
|
||||
return;
|
||||
}
|
||||
timings[ggml_op_name(node->op)].push_back(time);
|
||||
}
|
||||
private:
|
||||
|
|
@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|||
uint32_t nrows = ggml_nrows(src0);
|
||||
uint32_t k = dst->ne[0];
|
||||
|
||||
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
||||
vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
|
||||
|
||||
// Reserve space for ivec2 per element, double buffered
|
||||
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
||||
const size_t x_sz = dbl_buf_size * 2;
|
||||
uint32_t dbl_buf_index = 0;
|
||||
|
||||
if (ctx->prealloc_size_x < x_sz) {
|
||||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
if (ctx->prealloc_x_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
|
|
@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|||
// largest elements. Repeat until we have the top K elements.
|
||||
// Need to do at least one iteration to write out the results.
|
||||
bool done_one_iter = false;
|
||||
uint32_t dbl_buf_index = 0;
|
||||
size_t dbl_buf_size;
|
||||
while (num_elements > k || !done_one_iter) {
|
||||
done_one_iter = true;
|
||||
|
||||
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||
// But if K is larger, then we need a larger workgroup
|
||||
|
|
@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|||
// Number of elements remaining after this pass
|
||||
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
||||
|
||||
pc2.ncols_output = num_dst_elements;
|
||||
|
||||
if (!done_one_iter) {
|
||||
// Reserve space for ivec2 per element, double buffered
|
||||
// K per workgroup per row
|
||||
dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
|
||||
dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t x_sz = dbl_buf_size * 2;
|
||||
|
||||
if (ctx->prealloc_size_x < x_sz) {
|
||||
ctx->prealloc_size_x = x_sz;
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
}
|
||||
|
||||
vk_subbuffer src_buf;
|
||||
vk_subbuffer dst_buf;
|
||||
|
||||
|
|
@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|||
if (num_elements > k) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
done_one_iter = true;
|
||||
}
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ layout (push_constant) uniform parameter {
|
|||
uint orig_ncols;
|
||||
uint ncols_input;
|
||||
uint ncols_output;
|
||||
uint k;
|
||||
uint nrows;
|
||||
uint first_pass;
|
||||
uint last_pass;
|
||||
|
|
@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) {
|
|||
const uint row_offset = row * p.ncols_input;
|
||||
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||
} else {
|
||||
const uint row_offset = row * p.orig_ncols;
|
||||
const uint row_offset = row * p.ncols_input;
|
||||
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||
}
|
||||
} else {
|
||||
|
|
@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) {
|
|||
}
|
||||
barrier();
|
||||
|
||||
if (p.ncols_output == 1) {
|
||||
if (p.k == 1) {
|
||||
// Fast path for single output - just do a max reduction
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||
if (col < s) {
|
||||
|
|
@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) {
|
|||
}
|
||||
}
|
||||
|
||||
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||
if (col < p.k) {
|
||||
if (p.last_pass != 0) {
|
||||
const uint row_offset = row * p.ncols_output;
|
||||
data_d[row_offset + col] = dst_row[col].x;
|
||||
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||
const uint row_offset = row * p.k;
|
||||
data_d[row_offset + col] = dst_row[col].x;
|
||||
}
|
||||
} else {
|
||||
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||
data_t[row_offset + col] = dst_row[col];
|
||||
if (gl_WorkGroupID.x * p.k + col < p.ncols_output) {
|
||||
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
|
||||
data_t[row_offset + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
|
|||
uint orig_ncols;
|
||||
uint ncols_input;
|
||||
uint ncols_output;
|
||||
uint k;
|
||||
uint nrows;
|
||||
uint first_pass;
|
||||
uint last_pass;
|
||||
|
|
@ -60,7 +61,7 @@ void topk(const uint row) {
|
|||
const uint row_offset = row * p.ncols_input;
|
||||
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||
} else {
|
||||
const uint row_offset = row * p.orig_ncols;
|
||||
const uint row_offset = row * p.ncols_input;
|
||||
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||
}
|
||||
} else {
|
||||
|
|
@ -68,7 +69,7 @@ void topk(const uint row) {
|
|||
}
|
||||
barrier();
|
||||
|
||||
if (p.ncols_output == 1) {
|
||||
if (p.k == 1) {
|
||||
// Fast path for single output - just do a max reduction
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||
if (tid < s) {
|
||||
|
|
@ -98,7 +99,7 @@ void topk(const uint row) {
|
|||
uint range_max = 0xFF800000;
|
||||
// How many are above the current range, and how many we need to find.
|
||||
uint total = 0;
|
||||
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
||||
uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
||||
|
||||
while (mask != 0) {
|
||||
barrier();
|
||||
|
|
@ -139,7 +140,7 @@ void topk(const uint row) {
|
|||
range_max = range_min + ((min_idx + 1) << shift);
|
||||
range_min = range_min + (min_idx << shift);
|
||||
|
||||
if (total == p.ncols_output) {
|
||||
if (total == p.k) {
|
||||
break;
|
||||
}
|
||||
total -= counts[min_idx];
|
||||
|
|
@ -179,13 +180,17 @@ void topk(const uint row) {
|
|||
barrier();
|
||||
}
|
||||
|
||||
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||
if (tid < p.k) {
|
||||
if (p.last_pass != 0) {
|
||||
const uint row_offset = row * p.ncols_output;
|
||||
data_d[row_offset + tid] = dst_row[tid].x;
|
||||
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||
const uint row_offset = row * p.k;
|
||||
data_d[row_offset + tid] = dst_row[tid].x;
|
||||
}
|
||||
} else {
|
||||
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||
data_t[row_offset + tid] = dst_row[tid];
|
||||
if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) {
|
||||
const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k;
|
||||
data_t[row_offset + tid] = dst_row[tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue