From 61bde8e21f4a1f9a98c9205831ca3e55457b4c78 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 2 Dec 2025 12:22:04 -0600 Subject: [PATCH] vulkan: Reduce temporary memory usage for TOP_K (#17623) - Compute row size for the temp buffer based on the output of the first pass. - Update shader addressing math to use the output row size - Pass the output row size as "ncols_output", what used to be "ncols_output" is now "k" For the common case of K=40 and src0=(200000,1,1,1), this reduces the temporary buffer from about 3.2MB to 500KB. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 39 +++++++++++++------ .../vulkan-shaders/topk_argsort.comp | 19 +++++---- .../vulkan-shaders/topk_nary_search.comp | 23 ++++++----- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 95966ce1d8..f917a745d5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants { uint32_t orig_ncols; uint32_t ncols_input; uint32_t ncols_output; + uint32_t k; uint32_t nrows; uint32_t first_pass; uint32_t last_pass; @@ -1673,6 +1674,14 @@ class vk_perf_logger { timings[name.str()].push_back(time); return; } + if (node->op == GGML_OP_TOP_K) { + std::stringstream name; + name << ggml_op_name(node->op) << + " K=" << node->ne[0] << + " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; + timings[name.str()].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons uint32_t nrows = ggml_nrows(src0); uint32_t k = dst->ne[0]; - vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 }; - // Reserve space for ivec2 per element, double buffered - const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); - const size_t x_sz = dbl_buf_size * 2; - uint32_t dbl_buf_index = 0; - - if (ctx->prealloc_size_x < x_sz) { - ctx->prealloc_size_x = x_sz; - ggml_vk_preallocate_buffers(ctx, subctx); - } if (ctx->prealloc_x_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } @@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // largest elements. Repeat until we have the top K elements. // Need to do at least one iteration to write out the results. bool done_one_iter = false; + uint32_t dbl_buf_index = 0; + size_t dbl_buf_size; while (num_elements > k || !done_one_iter) { - done_one_iter = true; // Prefer going as small as num_topk_pipelines - 3 for perf reasons. // But if K is larger, then we need a larger workgroup @@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Number of elements remaining after this pass uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + pc2.ncols_output = num_dst_elements; + + if (!done_one_iter) { + // Reserve space for ivec2 per element, double buffered + // K per workgroup per row + dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int); + dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const size_t x_sz = dbl_buf_size * 2; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + } + vk_subbuffer src_buf; vk_subbuffer dst_buf; @@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons if (num_elements > k) { ggml_vk_sync_buffers(ctx, subctx); } + done_one_iter = true; } ctx->prealloc_x_need_sync = true; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp index cd858b7d32..49d4ab8e7c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -36,7 +37,7 @@ void topk(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols_input; dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -44,7 +45,7 @@ void topk(bool needs_bounds_check, const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (col < s) { @@ -84,13 +85,17 @@ void topk(bool needs_bounds_check, const uint row) { } } - if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (col < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + col] = dst_row[col].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + col] = dst_row[col].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + col] = dst_row[col]; + if (gl_WorkGroupID.x * p.k + col < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + col] = dst_row[col]; + } } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp index c902e60237..f794285ee1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -25,6 +25,7 @@ layout (push_constant) uniform parameter { uint orig_ncols; uint ncols_input; uint ncols_output; + uint k; uint nrows; uint first_pass; uint last_pass; @@ -60,7 +61,7 @@ void topk(const uint row) { const uint row_offset = row * p.ncols_input; dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); } else { - const uint row_offset = row * p.orig_ncols; + const uint row_offset = row * p.ncols_input; dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; } } else { @@ -68,7 +69,7 @@ void topk(const uint row) { } barrier(); - if (p.ncols_output == 1) { + if (p.k == 1) { // Fast path for single output - just do a max reduction [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { if (tid < s) { @@ -98,7 +99,7 @@ void topk(const uint row) { uint range_max = 0xFF800000; // How many are above the current range, and how many we need to find. uint total = 0; - uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + uint limit = min(p.k, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); while (mask != 0) { barrier(); @@ -139,7 +140,7 @@ void topk(const uint row) { range_max = range_min + ((min_idx + 1) << shift); range_min = range_min + (min_idx << shift); - if (total == p.ncols_output) { + if (total == p.k) { break; } total -= counts[min_idx]; @@ -179,13 +180,17 @@ void topk(const uint row) { barrier(); } - if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (tid < p.k) { if (p.last_pass != 0) { - const uint row_offset = row * p.ncols_output; - data_d[row_offset + tid] = dst_row[tid].x; + if (gl_GlobalInvocationID.x < p.ncols_input) { + const uint row_offset = row * p.k; + data_d[row_offset + tid] = dst_row[tid].x; + } } else { - const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; - data_t[row_offset + tid] = dst_row[tid]; + if (gl_WorkGroupID.x * p.k + tid < p.ncols_output) { + const uint row_offset = row * p.ncols_output + gl_WorkGroupID.x * p.k; + data_t[row_offset + tid] = dst_row[tid]; + } } } }