vulkan: Implement top-k (#17418)
* vulkan: Implement top-k Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10) and discards all but the top K. Repeat until only K are left. And there's a fast path when K==1 to just find the max value rather than sorting. * fix pipeline selection * vulkan: Add N-ary search algorithm for topk * microoptimizations
This commit is contained in:
parent
6ab4e50d9c
commit
879d673759
|
|
@ -409,6 +409,7 @@ enum shader_reduction_mode {
|
||||||
// argsort pipelines for up to 1<<10 invocations per workgroup
|
// argsort pipelines for up to 1<<10 invocations per workgroup
|
||||||
static constexpr uint32_t num_argsort_pipelines = 11;
|
static constexpr uint32_t num_argsort_pipelines = 11;
|
||||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||||
|
static constexpr uint32_t num_topk_pipelines = 11;
|
||||||
|
|
||||||
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||||
|
|
@ -515,6 +516,7 @@ struct vk_device_struct {
|
||||||
bool single_queue;
|
bool single_queue;
|
||||||
bool support_async;
|
bool support_async;
|
||||||
uint32_t subgroup_size;
|
uint32_t subgroup_size;
|
||||||
|
uint32_t subgroup_size_log2;
|
||||||
uint32_t shader_core_count;
|
uint32_t shader_core_count;
|
||||||
bool uma;
|
bool uma;
|
||||||
bool prefer_host_memory;
|
bool prefer_host_memory;
|
||||||
|
|
@ -704,6 +706,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
||||||
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
||||||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||||
|
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
||||||
vk_pipeline pipeline_sum_rows_f32;
|
vk_pipeline pipeline_sum_rows_f32;
|
||||||
vk_pipeline pipeline_cumsum_f32;
|
vk_pipeline pipeline_cumsum_f32;
|
||||||
vk_pipeline pipeline_argmax_f32;
|
vk_pipeline pipeline_argmax_f32;
|
||||||
|
|
@ -1205,6 +1208,15 @@ struct vk_op_argsort_push_constants {
|
||||||
uint32_t inner_end;
|
uint32_t inner_end;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct vk_op_topk_push_constants {
|
||||||
|
uint32_t orig_ncols;
|
||||||
|
uint32_t ncols_input;
|
||||||
|
uint32_t ncols_output;
|
||||||
|
uint32_t nrows;
|
||||||
|
uint32_t first_pass;
|
||||||
|
uint32_t last_pass;
|
||||||
|
};
|
||||||
|
|
||||||
struct vk_op_im2col_push_constants {
|
struct vk_op_im2col_push_constants {
|
||||||
uint64_t dst_addr;
|
uint64_t dst_addr;
|
||||||
uint32_t batch_offset; uint32_t offset_delta;
|
uint32_t batch_offset; uint32_t offset_delta;
|
||||||
|
|
@ -3965,6 +3977,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
|
||||||
|
const uint32_t BLOCK_SIZE = 1u << i;
|
||||||
|
const uint32_t NCOLS_PADDED_LOG2 = i;
|
||||||
|
if (i <= device->max_workgroup_size_log2) {
|
||||||
|
uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
|
||||||
|
sizeof(int) * device->subgroup_size +
|
||||||
|
2 * sizeof(int) +
|
||||||
|
(BLOCK_SIZE / device->subgroup_size) * sizeof(int);
|
||||||
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
|
||||||
|
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
|
||||||
|
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
|
||||||
|
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
@ -4336,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
|
||||||
|
|
||||||
device->subgroup_size = subgroup_props.subgroupSize;
|
device->subgroup_size = subgroup_props.subgroupSize;
|
||||||
|
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
|
||||||
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||||
if (sm_builtins) {
|
if (sm_builtins) {
|
||||||
device->shader_core_count = sm_props.shaderSMCount;
|
device->shader_core_count = sm_props.shaderSMCount;
|
||||||
|
|
@ -10143,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
uint32_t ncols = src0->ne[0];
|
||||||
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
|
uint32_t k = dst->ne[0];
|
||||||
|
|
||||||
|
vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
||||||
|
|
||||||
|
// Reserve space for ivec2 per element, double buffered
|
||||||
|
const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
||||||
|
const size_t x_sz = dbl_buf_size * 2;
|
||||||
|
uint32_t dbl_buf_index = 0;
|
||||||
|
|
||||||
|
if (ctx->prealloc_size_x < x_sz) {
|
||||||
|
ctx->prealloc_size_x = x_sz;
|
||||||
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
if (ctx->prealloc_x_need_sync) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<uint32_t, 3> elements;
|
||||||
|
elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||||
|
elements[2] = 1;
|
||||||
|
|
||||||
|
uint32_t num_elements = ncols;
|
||||||
|
|
||||||
|
// Each iteration reduces a workgroup's worth of elements down to the K
|
||||||
|
// largest elements. Repeat until we have the top K elements.
|
||||||
|
// Need to do at least one iteration to write out the results.
|
||||||
|
bool done_one_iter = false;
|
||||||
|
while (num_elements > k || !done_one_iter) {
|
||||||
|
done_one_iter = true;
|
||||||
|
|
||||||
|
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
||||||
|
// But if K is larger, then we need a larger workgroup
|
||||||
|
uint32_t max_pipeline = num_topk_pipelines - 3;
|
||||||
|
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
|
||||||
|
// require full subgroup
|
||||||
|
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
|
||||||
|
|
||||||
|
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
|
||||||
|
pipeline_idx = std::min(pipeline_idx, max_pipeline);
|
||||||
|
pipeline_idx = std::max(pipeline_idx, min_pipeline);
|
||||||
|
|
||||||
|
if (num_elements > (1u << pipeline_idx)) {
|
||||||
|
// If we could finish on this loop iteration (i.e. a single workgroup)
|
||||||
|
// then do so. It's better than the overhead of another pass.
|
||||||
|
for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
|
||||||
|
if (num_elements <= (1u << i)) {
|
||||||
|
pipeline_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
|
||||||
|
// If the device doesn't support a pipeline this large, use smaller
|
||||||
|
while (!pipeline) {
|
||||||
|
pipeline_idx--;
|
||||||
|
GGML_ASSERT(pipeline_idx >= min_pipeline);
|
||||||
|
pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
vk_op_topk_push_constants pc2 = pc;
|
||||||
|
pc2.ncols_input = num_elements;
|
||||||
|
|
||||||
|
// Number of elements remaining after this pass
|
||||||
|
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
||||||
|
|
||||||
|
vk_subbuffer src_buf;
|
||||||
|
vk_subbuffer dst_buf;
|
||||||
|
|
||||||
|
if (num_elements == ncols) {
|
||||||
|
pc2.first_pass = 1;
|
||||||
|
src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
|
||||||
|
} else {
|
||||||
|
src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
|
||||||
|
}
|
||||||
|
if (num_dst_elements == k) {
|
||||||
|
pc2.last_pass = 1;
|
||||||
|
dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
|
||||||
|
} else {
|
||||||
|
dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[0] = num_elements;
|
||||||
|
|
||||||
|
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||||
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
|
||||||
|
num_elements = num_dst_elements;
|
||||||
|
dbl_buf_index ^= 1;
|
||||||
|
if (num_elements > k) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx->prealloc_x_need_sync = true;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
|
vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
|
||||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
|
||||||
|
|
@ -11755,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
ggml_vk_argsort(ctx, compute_ctx, src0, node);
|
ggml_vk_argsort(ctx, compute_ctx, src0, node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
ggml_vk_topk(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_SUM:
|
case GGML_OP_SUM:
|
||||||
ggml_vk_sum(ctx, compute_ctx, src0, node);
|
ggml_vk_sum(ctx, compute_ctx, src0, node);
|
||||||
|
|
@ -13787,6 +13919,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
return op->ne[0] <= (1 << device->max_workgroup_size_log2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case GGML_OP_TOP_K:
|
||||||
|
{
|
||||||
|
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||||
|
auto device = ggml_vk_get_device(ctx->device);
|
||||||
|
// We could potentially support larger, using argsort to sort the
|
||||||
|
// whole thing. Not clear if this is needed.
|
||||||
|
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
|
||||||
|
if (min_pipeline >= num_topk_pipelines ||
|
||||||
|
!device->pipeline_topk_f32[min_pipeline]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
|
|
@ -14459,6 +14607,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
|
||||||
} else if (tensor->op == GGML_OP_ARGSORT) {
|
} else if (tensor->op == GGML_OP_ARGSORT) {
|
||||||
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
|
tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
|
||||||
|
} else if (tensor->op == GGML_OP_TOP_K) {
|
||||||
|
tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
|
||||||
} else if (tensor->op == GGML_OP_SUM) {
|
} else if (tensor->op == GGML_OP_SUM) {
|
||||||
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
|
||||||
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
} else if (tensor->op == GGML_OP_SUM_ROWS) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
#version 450
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
|
layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
// Input can either be the source (A) or intermediate values (S).
|
||||||
|
// Similarly, output can be either destination (D) or intermediate values (S).
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {int data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint orig_ncols;
|
||||||
|
uint ncols_input;
|
||||||
|
uint ncols_output;
|
||||||
|
uint nrows;
|
||||||
|
uint first_pass;
|
||||||
|
uint last_pass;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
// pairs of (gid, value)
|
||||||
|
shared ivec2 dst_row[BLOCK_SIZE];
|
||||||
|
|
||||||
|
void topk(bool needs_bounds_check, const uint row) {
|
||||||
|
const int col = int(gl_LocalInvocationID.x);
|
||||||
|
|
||||||
|
// initialize indices
|
||||||
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
|
if (p.first_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_input;
|
||||||
|
dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols;
|
||||||
|
dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dst_row[col] = ivec2(p.orig_ncols, 0);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (p.ncols_output == 1) {
|
||||||
|
// Fast path for single output - just do a max reduction
|
||||||
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
|
if (col < s) {
|
||||||
|
ivec2 a = dst_row[col];
|
||||||
|
ivec2 b = dst_row[col + s];
|
||||||
|
if (a.x >= p.orig_ncols ||
|
||||||
|
b.x < p.orig_ncols && b.y > a.y) {
|
||||||
|
dst_row[col] = b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// bitonic sort on this group of elements
|
||||||
|
uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
|
||||||
|
for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
|
||||||
|
uint num_inner_loop_iters = outer_idx + 1;
|
||||||
|
for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
|
||||||
|
const int ixj = int(col ^ j);
|
||||||
|
|
||||||
|
int idx_0 = (col & k) == 0 ? col : ixj;
|
||||||
|
int idx_1 = (col & k) == 0 ? ixj : col;
|
||||||
|
|
||||||
|
ivec2 sh_idx_0 = dst_row[idx_0];
|
||||||
|
ivec2 sh_idx_1 = dst_row[idx_1];
|
||||||
|
bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
|
||||||
|
bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
|
||||||
|
|
||||||
|
if ((idx_0_oob ||
|
||||||
|
(!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
|
||||||
|
dst_row[idx_0] = sh_idx_1;
|
||||||
|
dst_row[idx_1] = sh_idx_0;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||||
|
if (p.last_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_output;
|
||||||
|
data_d[row_offset + col] = dst_row[col].x;
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||||
|
data_t[row_offset + col] = dst_row[col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
// Fast path for fully occupied workgroups
|
||||||
|
if ((p.ncols_input % BLOCK_SIZE) == 0) {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(false, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(true, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,199 @@
|
||||||
|
#version 450
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_debug_printf : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_basic : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||||
|
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
|
||||||
|
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||||
|
layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
|
||||||
|
layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
// Input can either be the source (A) or intermediate values (S).
|
||||||
|
// Similarly, output can be either destination (D) or intermediate values (S).
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 0) readonly buffer S {ivec2 data_s[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {int data_d[];};
|
||||||
|
layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter {
|
||||||
|
uint orig_ncols;
|
||||||
|
uint ncols_input;
|
||||||
|
uint ncols_output;
|
||||||
|
uint nrows;
|
||||||
|
uint first_pass;
|
||||||
|
uint last_pass;
|
||||||
|
} p;
|
||||||
|
|
||||||
|
// pairs of (gid, value)
|
||||||
|
shared ivec2 dst_row[BLOCK_SIZE];
|
||||||
|
|
||||||
|
shared int counts[SUBGROUP_SIZE];
|
||||||
|
shared int sh_min_idx;
|
||||||
|
shared uint sh_total;
|
||||||
|
shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
|
||||||
|
|
||||||
|
// Map float values to uint such that comparisons still work.
|
||||||
|
// Positive values set the high bit, negative values are inverted.
|
||||||
|
// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
|
||||||
|
uint f2ui(float x) {
|
||||||
|
uint y = floatBitsToUint(x);
|
||||||
|
if ((y & 0x80000000) != 0) {
|
||||||
|
y ^= ~0;
|
||||||
|
} else {
|
||||||
|
y |= 0x80000000;
|
||||||
|
}
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
void topk(const uint row) {
|
||||||
|
const int tid = int(gl_LocalInvocationID.x);
|
||||||
|
|
||||||
|
// initialize indices
|
||||||
|
if (gl_GlobalInvocationID.x < p.ncols_input) {
|
||||||
|
if (p.first_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_input;
|
||||||
|
dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols;
|
||||||
|
dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (p.ncols_output == 1) {
|
||||||
|
// Fast path for single output - just do a max reduction
|
||||||
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
|
||||||
|
if (tid < s) {
|
||||||
|
ivec2 a = dst_row[tid];
|
||||||
|
ivec2 b = dst_row[tid + s];
|
||||||
|
if (a.x >= p.orig_ncols ||
|
||||||
|
b.x < p.orig_ncols && b.y > a.y) {
|
||||||
|
dst_row[tid] = b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Do an N-ary search to find the K-th largest value.
|
||||||
|
// We remap the float values to be comparable as unsigned integers,
|
||||||
|
// and split the range into 2^N smaller ranges where N is the
|
||||||
|
// subgroup size. Count how many values are in each range, if the K-th
|
||||||
|
// largest value is in the middle of one of thee ranges then repeat
|
||||||
|
// and split again.
|
||||||
|
|
||||||
|
// Mask is the current set of bits we're searching. Shift is the LSB index.
|
||||||
|
int shift = 32 - SUBGROUP_SIZE_LOG2;
|
||||||
|
uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
|
||||||
|
|
||||||
|
// The current range.
|
||||||
|
uint range_min = 0;
|
||||||
|
uint range_max = 0xFF800000;
|
||||||
|
// How many are above the current range, and how many we need to find.
|
||||||
|
uint total = 0;
|
||||||
|
uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
|
||||||
|
|
||||||
|
while (mask != 0) {
|
||||||
|
barrier();
|
||||||
|
// Initialize bucket counts to zero.
|
||||||
|
if (tid < SUBGROUP_SIZE) {
|
||||||
|
counts[tid] = 0;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
// Count how many values are in each bucket.
|
||||||
|
if (tid < p.ncols_input) {
|
||||||
|
float y = intBitsToFloat(dst_row[tid].y);
|
||||||
|
uint fy = f2ui(y);
|
||||||
|
if (fy >= range_min && fy < range_max) {
|
||||||
|
uint bucket = (fy & mask) >> shift;
|
||||||
|
atomicAdd(counts[bucket], 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
// On the first subgroup, do a scan to count (from the top down) how
|
||||||
|
// many elements are in the top N buckets. Find the index of the first
|
||||||
|
// that is over the limit. Copy it to the other invocations through
|
||||||
|
// shared memory.
|
||||||
|
if (tid < SUBGROUP_SIZE) {
|
||||||
|
uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
|
||||||
|
partial_sum = subgroupInclusiveAdd(partial_sum) + total;
|
||||||
|
uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
|
||||||
|
if (tid == t) {
|
||||||
|
sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
|
||||||
|
sh_total = partial_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
int min_idx = sh_min_idx;
|
||||||
|
total = sh_total;
|
||||||
|
|
||||||
|
// Update the range, and break if we've found the K-th largest.
|
||||||
|
range_max = range_min + ((min_idx + 1) << shift);
|
||||||
|
range_min = range_min + (min_idx << shift);
|
||||||
|
|
||||||
|
if (total == p.ncols_output) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
total -= counts[min_idx];
|
||||||
|
mask >>= SUBGROUP_SIZE_LOG2;
|
||||||
|
shift -= SUBGROUP_SIZE_LOG2;
|
||||||
|
if (shift < 0) {
|
||||||
|
shift = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ivec2 v = dst_row[tid];
|
||||||
|
|
||||||
|
// We need to compact these values to the start of the dst_row array.
|
||||||
|
// Have each subgroup count how many items it'll store, so other
|
||||||
|
// subgroups can compute their base offset.
|
||||||
|
bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
|
||||||
|
uvec4 b = subgroupBallot(top);
|
||||||
|
uint bit_count = subgroupBallotBitCount(b);
|
||||||
|
if ((tid % SUBGROUP_SIZE) == 0) {
|
||||||
|
offset_partials[tid / SUBGROUP_SIZE] = bit_count;
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
uint out_idx = 0;
|
||||||
|
[[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
|
||||||
|
if (i < tid / SUBGROUP_SIZE) {
|
||||||
|
out_idx += offset_partials[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
|
||||||
|
if (top) {
|
||||||
|
// TODO: Copy directly to the output?
|
||||||
|
dst_row[out_idx + bit_count_ex] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
|
||||||
|
if (p.last_pass != 0) {
|
||||||
|
const uint row_offset = row * p.ncols_output;
|
||||||
|
data_d[row_offset + tid] = dst_row[tid].x;
|
||||||
|
} else {
|
||||||
|
const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
|
||||||
|
data_t[row_offset + tid] = dst_row[tid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
topk(row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -913,6 +913,9 @@ void process_shaders() {
|
||||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||||
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
|
||||||
|
string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
|
||||||
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
|
||||||
|
|
|
||||||
|
|
@ -7635,6 +7635,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; ++i) {
|
||||||
|
for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
|
||||||
|
if (k <= 1<<i) {
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int k : {1, 2, 3, 7, 15}) {
|
for (int k : {1, 2, 3, 7, 15}) {
|
||||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
|
||||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
|
||||||
|
|
@ -8032,7 +8040,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
}
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
|
||||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
|
for (auto k : {1, 10, 40}) {
|
||||||
|
for (auto nrows : {1, 16}) {
|
||||||
|
for (auto cols : {k, 1000, 65000, 200000}) {
|
||||||
|
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return test_cases;
|
return test_cases;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue