workaround unit test failure for TOP_K

This commit is contained in:
Nakasaka, Masato 2026-04-13 01:11:39 -07:00
parent b5249e9f43
commit 04b7af9563
1 changed files with 6 additions and 8 deletions

View File

@ -610,7 +610,6 @@ struct vk_device_struct {
bool support_async;
bool async_use_transfer_queue;
uint32_t subgroup_size;
uint32_t subgroup_size_log2;
uint32_t shader_core_count;
bool uma;
bool prefer_host_memory;
@ -3176,7 +3175,7 @@ static const std::unordered_map<std::string, PipelineConfigParameter> xe2_onward
// Intel GPU can use subgroup 8, 16, or 32 depending on architeture.
// Pre-Xe2 is 8, 16, or 32. Xe2 onward is 16 or 32. 32 is the default if nothing is specified.
static constexpr uint32_t INTEL_DEFAULT_SUBGROUP_SIZE = 32;
static constexpr uint32_t INTEL_DEFAULT_SUBGROUP_SIZE = 16;
// Define configurations for different GPUs.
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
@ -3255,6 +3254,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
std::lock_guard<std::recursive_mutex> guard(device->mutex);
const uint32_t default_subgroup_size = get_subgroup_size(device);
const uint32_t subgroup_size_log2 = uint32_t(log2f(float(default_subgroup_size)));
// some shaders have a minimum subgroup size
const uint32_t subgroup_size_8 = std::max(default_subgroup_size, 8u);
@ -4669,10 +4669,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2 * sizeof(int) +
2 * (BLOCK_SIZE / default_subgroup_size) * sizeof(int);
if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, default_subgroup_size, device->subgroup_size_log2}, 1, true, true, default_subgroup_size);
nary_shmem <= device->properties.limits.maxComputeSharedMemorySize &&
BLOCK_SIZE >= default_subgroup_size) { // The n-ary top-k shader needs at least one full subgroup per workgroup.
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_nary_search_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, default_subgroup_size, subgroup_size_log2}, 1, true, true, default_subgroup_size);
} else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_argsort_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
}
}
}
@ -5109,7 +5110,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
device->subgroup_size = subgroup_props.subgroupSize;
device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount;
@ -11535,8 +11535,6 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
pipeline_idx = std::min(pipeline_idx, max_pipeline);