From 879d673759181987300a29989478c0f36c97e97b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 26 Nov 2025 09:45:43 -0600 Subject: [PATCH] vulkan: Implement top-k (#17418) * vulkan: Implement top-k Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10) and discards all but the top K. Repeat until only K are left. And there's a fast path when K==1 to just find the max value rather than sorting. * fix pipeline selection * vulkan: Add N-ary search algorithm for topk * microoptimizations --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 150 +++++++++++++ .../vulkan-shaders/topk_argsort.comp | 113 ++++++++++ .../vulkan-shaders/topk_nary_search.comp | 199 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + tests/test-backend-ops.cpp | 16 +- 5 files changed, 480 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 05e585ebff..9c97f0a6fa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -409,6 +409,7 @@ enum shader_reduction_mode { // argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t num_topk_moe_pipelines = 10; +static constexpr uint32_t num_topk_pipelines = 11; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, @@ -515,6 +516,7 @@ struct vk_device_struct { bool single_queue; bool support_async; uint32_t subgroup_size; + uint32_t subgroup_size_log2; uint32_t shader_core_count; bool uma; bool prefer_host_memory; @@ -704,6 +706,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; + vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; @@ -1205,6 +1208,15 @@ struct vk_op_argsort_push_constants { uint32_t inner_end; }; +struct vk_op_topk_push_constants { + uint32_t orig_ncols; + uint32_t ncols_input; + uint32_t ncols_output; + uint32_t nrows; + uint32_t first_pass; + uint32_t last_pass; +}; + struct vk_op_im2col_push_constants { uint64_t dst_addr; uint32_t batch_offset; uint32_t offset_delta; @@ -3965,6 +3977,23 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } + for (uint32_t i = 0; i < num_topk_pipelines; ++i) { + const uint32_t BLOCK_SIZE = 1u << i; + const uint32_t NCOLS_PADDED_LOG2 = i; + if (i <= device->max_workgroup_size_log2) { + uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE + + sizeof(int) * device->subgroup_size + + 2 * sizeof(int) + + (BLOCK_SIZE / device->subgroup_size) * sizeof(int); + if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot && + nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size); + } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + } + } + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -4336,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); device->subgroup_size = subgroup_props.subgroupSize; + device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size))); device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; if (sm_builtins) { device->shader_core_count = sm_props.shaderSMCount; @@ -10143,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c } } +static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ncols = src0->ne[0]; + uint32_t nrows = ggml_nrows(src0); + uint32_t k = dst->ne[0]; + + vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 }; + + // Reserve space for ivec2 per element, double buffered + const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int); + const size_t x_sz = dbl_buf_size * 2; + uint32_t dbl_buf_index = 0; + + if (ctx->prealloc_size_x < x_sz) { + ctx->prealloc_size_x = x_sz; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + std::array elements; + elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = 1; + + uint32_t num_elements = ncols; + + // Each iteration reduces a workgroup's worth of elements down to the K + // largest elements. Repeat until we have the top K elements. + // Need to do at least one iteration to write out the results. + bool done_one_iter = false; + while (num_elements > k || !done_one_iter) { + done_one_iter = true; + + // Prefer going as small as num_topk_pipelines - 3 for perf reasons. + // But if K is larger, then we need a larger workgroup + uint32_t max_pipeline = num_topk_pipelines - 3; + uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; + // require full subgroup + min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); + + uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements))); + pipeline_idx = std::min(pipeline_idx, max_pipeline); + pipeline_idx = std::max(pipeline_idx, min_pipeline); + + if (num_elements > (1u << pipeline_idx)) { + // If we could finish on this loop iteration (i.e. a single workgroup) + // then do so. It's better than the overhead of another pass. + for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) { + if (num_elements <= (1u << i)) { + pipeline_idx = i; + break; + } + } + } + + vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + // If the device doesn't support a pipeline this large, use smaller + while (!pipeline) { + pipeline_idx--; + GGML_ASSERT(pipeline_idx >= min_pipeline); + pipeline = ctx->device->pipeline_topk_f32[pipeline_idx]; + } + + vk_op_topk_push_constants pc2 = pc; + pc2.ncols_input = num_elements; + + // Number of elements remaining after this pass + uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]); + + vk_subbuffer src_buf; + vk_subbuffer dst_buf; + + if (num_elements == ncols) { + pc2.first_pass = 1; + src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + } else { + src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size }; + } + if (num_dst_elements == k) { + pc2.last_pass = 1; + dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + } else { + dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size }; + } + + elements[0] = num_elements; + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements); + num_elements = num_dst_elements; + dbl_buf_index ^= 1; + if (num_elements > k) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; +} + static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p); @@ -11755,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_vk_argsort(ctx, compute_ctx, src0, node); } + break; + case GGML_OP_TOP_K: + ggml_vk_topk(ctx, compute_ctx, src0, node); + break; case GGML_OP_SUM: ggml_vk_sum(ctx, compute_ctx, src0, node); @@ -13787,6 +13919,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->ne[0] <= (1 << device->max_workgroup_size_log2); } } + case GGML_OP_TOP_K: + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // We could potentially support larger, using argsort to sort the + // whole thing. Not clear if this is needed. + uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; + if (min_pipeline >= num_topk_pipelines || + !device->pipeline_topk_f32[min_pipeline]) { + return false; + } + } + return true; case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: @@ -14459,6 +14607,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ARGSORT) { tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_TOP_K) { + tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]); } else if (tensor->op == GGML_OP_SUM) { tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SUM_ROWS) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp new file mode 100644 index 0000000000..cd858b7d32 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp @@ -0,0 +1,113 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +void topk(bool needs_bounds_check, const uint row) { + const int col = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[col] = ivec2(p.orig_ncols, 0); + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (col < s) { + ivec2 a = dst_row[col]; + ivec2 b = dst_row[col + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[col] = b; + } + } + barrier(); + } + } else { + // bitonic sort on this group of elements + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; + for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; + } + + barrier(); + } + } + } + + if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + col] = dst_row[col].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + col] = dst_row[col]; + } + } +} + +void main() { + // Fast path for fully occupied workgroups + if ((p.ncols_input % BLOCK_SIZE) == 0) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp new file mode 100644 index 0000000000..c902e60237 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp @@ -0,0 +1,199 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_debug_printf : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int SUBGROUP_SIZE = 32; +layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// Input can either be the source (A) or intermediate values (S). +// Similarly, output can be either destination (D) or intermediate values (S). +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 0) readonly buffer S {ivec2 data_s[];}; +layout (binding = 1) writeonly buffer D {int data_d[];}; +layout (binding = 1) writeonly buffer T {ivec2 data_t[];}; + +layout (push_constant) uniform parameter { + uint orig_ncols; + uint ncols_input; + uint ncols_output; + uint nrows; + uint first_pass; + uint last_pass; +} p; + +// pairs of (gid, value) +shared ivec2 dst_row[BLOCK_SIZE]; + +shared int counts[SUBGROUP_SIZE]; +shared int sh_min_idx; +shared uint sh_total; +shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE]; + +// Map float values to uint such that comparisons still work. +// Positive values set the high bit, negative values are inverted. +// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places. +uint f2ui(float x) { + uint y = floatBitsToUint(x); + if ((y & 0x80000000) != 0) { + y ^= ~0; + } else { + y |= 0x80000000; + } + return y; +} + +void topk(const uint row) { + const int tid = int(gl_LocalInvocationID.x); + + // initialize indices + if (gl_GlobalInvocationID.x < p.ncols_input) { + if (p.first_pass != 0) { + const uint row_offset = row * p.ncols_input; + dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x])); + } else { + const uint row_offset = row * p.orig_ncols; + dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x]; + } + } else { + dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf + } + barrier(); + + if (p.ncols_output == 1) { + // Fast path for single output - just do a max reduction + [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) { + if (tid < s) { + ivec2 a = dst_row[tid]; + ivec2 b = dst_row[tid + s]; + if (a.x >= p.orig_ncols || + b.x < p.orig_ncols && b.y > a.y) { + dst_row[tid] = b; + } + } + barrier(); + } + } else { + // Do an N-ary search to find the K-th largest value. + // We remap the float values to be comparable as unsigned integers, + // and split the range into 2^N smaller ranges where N is the + // subgroup size. Count how many values are in each range, if the K-th + // largest value is in the middle of one of thee ranges then repeat + // and split again. + + // Mask is the current set of bits we're searching. Shift is the LSB index. + int shift = 32 - SUBGROUP_SIZE_LOG2; + uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift; + + // The current range. + uint range_min = 0; + uint range_max = 0xFF800000; + // How many are above the current range, and how many we need to find. + uint total = 0; + uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE); + + while (mask != 0) { + barrier(); + // Initialize bucket counts to zero. + if (tid < SUBGROUP_SIZE) { + counts[tid] = 0; + } + barrier(); + // Count how many values are in each bucket. + if (tid < p.ncols_input) { + float y = intBitsToFloat(dst_row[tid].y); + uint fy = f2ui(y); + if (fy >= range_min && fy < range_max) { + uint bucket = (fy & mask) >> shift; + atomicAdd(counts[bucket], 1); + } + } + barrier(); + + // On the first subgroup, do a scan to count (from the top down) how + // many elements are in the top N buckets. Find the index of the first + // that is over the limit. Copy it to the other invocations through + // shared memory. + if (tid < SUBGROUP_SIZE) { + uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid]; + partial_sum = subgroupInclusiveAdd(partial_sum) + total; + uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit)); + if (tid == t) { + sh_min_idx = int(SUBGROUP_SIZE - 1 - t); + sh_total = partial_sum; + } + } + barrier(); + int min_idx = sh_min_idx; + total = sh_total; + + // Update the range, and break if we've found the K-th largest. + range_max = range_min + ((min_idx + 1) << shift); + range_min = range_min + (min_idx << shift); + + if (total == p.ncols_output) { + break; + } + total -= counts[min_idx]; + mask >>= SUBGROUP_SIZE_LOG2; + shift -= SUBGROUP_SIZE_LOG2; + if (shift < 0) { + shift = 0; + } + } + + ivec2 v = dst_row[tid]; + + // We need to compact these values to the start of the dst_row array. + // Have each subgroup count how many items it'll store, so other + // subgroups can compute their base offset. + bool top = f2ui(intBitsToFloat(v.y)) >= range_min; + uvec4 b = subgroupBallot(top); + uint bit_count = subgroupBallotBitCount(b); + if ((tid % SUBGROUP_SIZE) == 0) { + offset_partials[tid / SUBGROUP_SIZE] = bit_count; + } + barrier(); + + uint out_idx = 0; + [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) { + if (i < tid / SUBGROUP_SIZE) { + out_idx += offset_partials[i]; + } + } + + uint bit_count_ex = subgroupBallotExclusiveBitCount(b); + if (top) { + // TODO: Copy directly to the output? + dst_row[out_idx + bit_count_ex] = v; + } + + barrier(); + } + + if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) { + if (p.last_pass != 0) { + const uint row_offset = row * p.ncols_output; + data_d[row_offset + tid] = dst_row[tid].x; + } else { + const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output; + data_t[row_offset + tid] = dst_row[tid]; + } + } +} + +void main() { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + topk(row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 164af35658..4a802ab1c2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -913,6 +913,9 @@ void process_shaders() { string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}}); + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2fc7d4c3c7..d7ac2bc178 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7635,6 +7635,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int i = 0; i < 20; ++i) { + for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) { + if (k <= 1<> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); - test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40)); + for (auto k : {1, 10, 40}) { + for (auto nrows : {1, 16}) { + for (auto cols : {k, 1000, 65000, 200000}) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); + } + } + } return test_cases; }