From f9889cf1c7f3115f31e0c5f3e60305a24052f8f9 Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Thu, 27 Nov 2025 16:40:41 +0100
Subject: [PATCH] Fix top-k comp & behavior for non-CUB path

Some changes were made in 5ea3be265ba6f8916daf52e19e3fb8efe9a03637 which were
incomplete. In the non-CUB case, the bitonic sort and its ncols <= 1024
limitation have to apply, the same as in argsort.cu.
---
 ggml/src/ggml-cuda/ggml-cuda.cu |  1 -
 ggml/src/ggml-cuda/top-k.cu     | 16 +++++++++-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ea4763fa6e..b6ac36a23c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4238,7 +4238,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_TOP_K:
-            return true;
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu
index 912c41626b..7d66fec495 100644
--- a/ggml/src/ggml-cuda/top-k.cu
+++ b/ggml/src/ggml-cuda/top-k.cu
@@ -50,7 +50,7 @@ static void top_k_cub(ggml_cuda_pool & pool,
     DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
 }
 
-#else
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
 
 static int next_power_of_2(int x) {
     int n = 1;
@@ -84,7 +84,7 @@ void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     for (int i = 0; i < nrows; i++) {
         top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
     }
-#else
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
     // Fall back to argsort + copy
     const int ncols_pad = next_power_of_2(ncols);
     const size_t shared_mem = ncols_pad * sizeof(int);
@@ -94,15 +94,17 @@ void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     int * tmp_dst = temp_dst_alloc.get();
 
     if (shared_mem > max_shared_mem || ncols > 1024) {
-#ifdef GGML_CUDA_USE_CUB
         argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
-#else
-        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
-#endif
     } else {
         argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
     }
     CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                  cudaMemcpyDeviceToDevice, stream));
-#endif // CUB_TOP_K_AVAILABLE
+#else // GGML_CUDA_USE_CUB
+    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+    int * tmp_dst = temp_dst_alloc.get();
+    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#endif
 }
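
A rough host-side sketch of the dispatch this patch establishes, for reference only.
The helpers top_k_supported() and top_k_path() below are hypothetical illustrations
(they do not exist in ggml), and the real shared-memory check is simplified into the
ncols > 1024 condition:

    // Hypothetical, simplified model of the post-patch top-k selection (plain C++).
    #include <cstdio>

    enum class TopKPath { CUB_TOP_K, CUB_ARGSORT, BITONIC_ARGSORT };

    // supports_op view: with CUB any ncols is accepted; without CUB only the
    // bitonic argsort path exists, so TOP_K shares the ncols <= 1024 limit of ARGSORT.
    static bool top_k_supported(int ncols, bool have_cub) {
        return have_cub || ncols <= 1024;
    }

    // ggml_cuda_op_top_k view: DeviceTopK when available, otherwise argsort
    // (CUB radix sort for large rows, bitonic sort for small ones) plus a strided copy.
    static TopKPath top_k_path(int ncols, bool have_cub_top_k, bool have_cub) {
        if (have_cub_top_k) {
            return TopKPath::CUB_TOP_K;
        }
        if (have_cub && ncols > 1024) {
            return TopKPath::CUB_ARGSORT;
        }
        return TopKPath::BITONIC_ARGSORT;
    }

    int main() {
        printf("ncols=4096, no CUB : supported=%d\n", top_k_supported(4096, false));    // 0
        printf("ncols=4096, CUB    : path=%d\n", (int) top_k_path(4096, false, true));  // CUB_ARGSORT
        printf("ncols=512,  no CUB : path=%d\n", (int) top_k_path(512, false, false));  // BITONIC_ARGSORT
        return 0;
    }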