From f23b306cc5f76941d07c433f9d3a789b1b8b0988 Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Fri, 21 Nov 2025 12:01:32 +0100
Subject: [PATCH] CUDA: Add top-k implementation

---
 cmake/CPM.cmake                   |  25 +++++++
 ggml/src/ggml-cuda/CMakeLists.txt |  17 +++++
 ggml/src/ggml-cuda/argsort.cu     |  26 ++++----
 ggml/src/ggml-cuda/argsort.cuh    |  16 +++++
 ggml/src/ggml-cuda/ggml-cuda.cu   |   6 ++
 ggml/src/ggml-cuda/top-k.cu       | 104 ++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/top-k.cuh      |   3 +
 tests/test-backend-ops.cpp        |   5 ++
 8 files changed, 189 insertions(+), 13 deletions(-)
 create mode 100644 cmake/CPM.cmake
 create mode 100644 ggml/src/ggml-cuda/top-k.cu
 create mode 100644 ggml/src/ggml-cuda/top-k.cuh

diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake
new file mode 100644
index 0000000000..978a1b7e39
--- /dev/null
+++ b/cmake/CPM.cmake
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: MIT
+#
+# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors
+
+# TODO: Remove this file once CCCL 3.2 is released & bundled with the CUDA Toolkit
+set(CPM_DOWNLOAD_VERSION 0.42.0)
+set(CPM_HASH_SUM "2020b4fc42dba44817983e06342e682ecfc3d2f484a581f11cc5731fbe4dce8a")
+
+if(CPM_SOURCE_CACHE)
+  set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+elseif(DEFINED ENV{CPM_SOURCE_CACHE})
+  set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+else()
+  set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
+endif()
+
+# Expand relative path. This is important if the provided path contains a tilde (~)
+get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE)
+
+file(DOWNLOAD
+     https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake
+     ${CPM_DOWNLOAD_LOCATION} EXPECTED_HASH SHA256=${CPM_HASH_SUM}
+)
+
+include(${CPM_DOWNLOAD_LOCATION})
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 67af1d8ccc..05a9b49e83 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -2,6 +2,17 @@ cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
 
 find_package(CUDAToolkit)
 
+# Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+if (GGML_CUDA_CUB_3DOT2)
+    include(../../../cmake/CPM.cmake)
+    # This will automatically clone CCCL from GitHub and make the exported cmake targets available
+    CPMAddPackage(
+        NAME CCCL
+        GITHUB_REPOSITORY nvidia/cccl
+        GIT_TAG v3.2.0-rc0 # Pin the 3.2.0 release candidate until CCCL 3.2 ships with the CUDA Toolkit
+    )
+endif()
+
 if (CUDAToolkit_FOUND)
     message(STATUS "CUDA Toolkit found")
 
@@ -102,6 +113,9 @@
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
             target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
         else ()
+            if (GGML_CUDA_CUB_3DOT2)
+                target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+            endif()
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
                 target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
             else()
@@ -109,6 +123,9 @@
             endif()
         endif()
     else()
+        if (GGML_CUDA_CUB_3DOT2)
+            target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+        endif()
         target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
     endif()
 
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index b8003c48c5..eb83e6547a 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
 }
 
 #ifdef GGML_CUDA_USE_CUB
-static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
-                                     const float * x,
-                                     int * dst,
-                                     const int ncols,
-                                     const int nrows,
-                                     ggml_sort_order order,
-                                     cudaStream_t stream) {
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float * x,
+                              int * dst,
+                              const int ncols,
+                              const int nrows,
+                              ggml_sort_order order,
+                              cudaStream_t stream) {
     ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
@@ -162,12 +162,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda_bitonic(const float * x,
-                                         int * dst,
-                                         const int ncols,
-                                         const int nrows,
-                                         ggml_sort_order order,
-                                         cudaStream_t stream) {
+void argsort_f32_i32_cuda_bitonic(const float * x,
+                                  int * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh
index 68a001547f..22b7306f20 100644
--- a/ggml/src/ggml-cuda/argsort.cuh
+++ b/ggml/src/ggml-cuda/argsort.cuh
@@ -1,3 +1,19 @@
 #include "common.cuh"
 
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                              const float * x,
+                              int * dst,
+                              const int ncols,
+                              const int nrows,
+                              ggml_sort_order order,
+                              cudaStream_t stream);
+#endif // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float * x,
+                                  int * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  ggml_sort_order order,
+                                  cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 98b0cea33a..ea4763fa6e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -44,6 +44,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/topk-moe.cuh"
@@ -2694,6 +2695,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SSM_SCAN:
             ggml_cuda_op_ssm_scan(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_cuda_op_top_k(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
@@ -4233,6 +4237,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CUMSUM:
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_TOP_K:
+            return true;
         case GGML_OP_ARGSORT:
 #ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu
new file mode 100644
index 0000000000..ae5989cacc
--- /dev/null
+++ b/ggml/src/ggml-cuda/top-k.cu
@@ -0,0 +1,104 @@
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#   include <cub/cub.cuh>
+#   if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
+#       define CUB_TOP_K_AVAILABLE
+using namespace cub;
+#   endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
+#endif // GGML_CUDA_USE_CUB
+
+#ifdef CUB_TOP_K_AVAILABLE
+static __global__ void init_indices(int * indices, const int ncols) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (col < ncols) {
+        indices[col] = col;
+    }
+}
+
+static void top_k_cub(ggml_cuda_pool & pool,
+                      const float * src,
+                      int * dst,
+                      const int ncols,
+                      const int k,
+                      cudaStream_t stream) {
+    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+                                                 cuda::execution::output_ordering::unsorted);
+    auto stream_env = cuda::stream_ref{ stream };
+    auto env = cuda::std::execution::env{ stream_env, requirements };
+
+    ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols);
+
+    int * temp_indices = temp_indices_alloc.get();
+    float * temp_keys = temp_keys_alloc.get();
+
+    static const int block_size = 256;
+    const dim3 grid_size((ncols + block_size - 1) / block_size, 1);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols);
+
+    CUDA_CHECK(cudaMemcpyAsync(temp_keys, src, ncols * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+
+    size_t temp_storage_bytes = 0;
+    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void * d_temp_storage = temp_storage_alloc.get();
+
+    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env);
+}
+
+#else
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+#endif // CUB_TOP_K_AVAILABLE
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    int * dst_d = (int *) dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    // are these asserts truly necessary?
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+    const int64_t k = dst->ne[0];
+    ggml_cuda_pool & pool = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+    // https://github.com/NVIDIA/cccl/issues/6391
+    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
+    for (int i = 0; i < nrows; i++) {
+        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+    }
+#else
+    // Fall back to argsort + copy
+    const int ncols_pad = next_power_of_2(ncols);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+    int * tmp_dst = temp_dst_alloc.get();
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+    }
+    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+                                 cudaMemcpyDeviceToDevice, stream));
+#endif // CUB_TOP_K_AVAILABLE
+}
diff --git a/ggml/src/ggml-cuda/top-k.cuh b/ggml/src/ggml-cuda/top-k.cuh
new file mode 100644
index 0000000000..85a95bc1bc
--- /dev/null
+++ b/ggml/src/ggml-cuda/top-k.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
\ No newline at end of file
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index eb49eb41a6..f4072bf4a6 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8035,6 +8035,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
     test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 1));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 400));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 40));
+    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 1));
 
     return test_cases;
 }