CUDA: Add top-k implementation
This commit is contained in:
parent
ec047e12ee
commit
f23b306cc5
|
|
@ -0,0 +1,25 @@
|
|||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors
|
||||
|
||||
# TODO: Remove this file once CCCL 3.2 is released & bundled with the CUDA Toolkit
|
||||
# Pinned CPM.cmake release and the SHA256 that the downloaded script must match
set(CPM_DOWNLOAD_VERSION 0.42.0)
set(CPM_HASH_SUM "2020b4fc42dba44817983e06342e682ecfc3d2f484a581f11cc5731fbe4dce8a")

# Choose where the script is stored: the CPM_SOURCE_CACHE variable takes precedence,
# then the CPM_SOURCE_CACHE environment variable, otherwise the build tree
if(CPM_SOURCE_CACHE)
  set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
elseif(DEFINED ENV{CPM_SOURCE_CACHE})
  set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
else()
  set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
endif()

# Expand relative path. This is important if the provided path contains a tilde (~)
# The path is quoted so that a location containing spaces or semicolons stays one argument
get_filename_component(CPM_DOWNLOAD_LOCATION "${CPM_DOWNLOAD_LOCATION}" ABSOLUTE)

# Fetch the pinned release and verify it against CPM_HASH_SUM
file(DOWNLOAD
     https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake
     "${CPM_DOWNLOAD_LOCATION}" EXPECTED_HASH SHA256=${CPM_HASH_SUM}
)

include("${CPM_DOWNLOAD_LOCATION}")
|
||||
|
|
@ -2,6 +2,17 @@ cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
|
|||
|
||||
find_package(CUDAToolkit)
|
||||
|
||||
# Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
|
||||
if (GGML_CUDA_CUB_3DOT2)
|
||||
include(../../../cmake/CPM.cmake)
|
||||
# This will automatically clone CCCL from GitHub and make the exported cmake targets available
|
||||
CPMAddPackage(
|
||||
NAME CCCL
|
||||
GITHUB_REPOSITORY nvidia/cccl
|
||||
GIT_TAG v3.2.0-rc0 # Pins CCCL to the v3.2.0-rc0 release-candidate tag
|
||||
)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "CUDA Toolkit found")
|
||||
|
||||
|
|
@ -102,6 +113,9 @@ if (CUDAToolkit_FOUND)
|
|||
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
|
||||
else ()
|
||||
if (GGML_CUDA_CUB_3DOT2)
|
||||
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
|
||||
endif()
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
else()
|
||||
|
|
@ -109,6 +123,9 @@ if (CUDAToolkit_FOUND)
|
|||
endif()
|
||||
endif()
|
||||
else()
|
||||
if (GGML_CUDA_CUB_3DOT2)
|
||||
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
|
||||
endif()
|
||||
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
|
|||
}
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
|
@ -162,12 +162,12 @@ static int next_power_of_2(int x) {
|
|||
return n;
|
||||
}
|
||||
|
||||
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,19 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream);
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream);
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@
|
|||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/top-k.cuh"
|
||||
#include "ggml-cuda/mean.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/topk-moe.cuh"
|
||||
|
|
@ -2694,6 +2695,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_SSM_SCAN:
|
||||
ggml_cuda_op_ssm_scan(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_TOP_K:
|
||||
ggml_cuda_op_top_k(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_cuda_op_argsort(ctx, dst);
|
||||
break;
|
||||
|
|
@ -4233,6 +4237,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_TOP_K:
|
||||
return true;
|
||||
case GGML_OP_ARGSORT:
|
||||
#ifndef GGML_CUDA_USE_CUB
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,104 @@
|
|||
#include "argsort.cuh"
|
||||
#include "top-k.cuh"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
#    include <cub/cub.cuh>
// The DeviceTopK path below is enabled for CCCL 3.2 and newer. The (major, minor) pair is
// compared lexicographically so that releases such as 4.0 or 4.1 qualify as well; testing
// `major >= 3 && minor >= 2` would wrongly reject them.
#    if (CCCL_MAJOR_VERSION > 3) || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 2)
#        define CUB_TOP_K_AVAILABLE
using namespace cub;
#    endif  // CCCL >= 3.2
#endif      // GGML_CUDA_USE_CUB
|
||||
|
||||
#ifdef CUB_TOP_K_AVAILABLE
|
||||
// Fills indices[0 .. ncols) with the identity mapping indices[i] = i.
// One thread per element on a 1D launch; threads past the end of the array do nothing.
static __global__ void init_indices(int * indices, const int ncols) {
    const int idx = threadIdx.x + blockDim.x * blockIdx.x;

    // tail guard: the grid is rounded up to a whole number of blocks
    if (idx >= ncols) {
        return;
    }

    indices[idx] = idx;
}
|
||||
|
||||
// Selects the top-k entries of a single row with cub::DeviceTopK::MaxPairs.
//
//   pool   - pool that provides all scratch buffers used here
//   src    - device pointer to `ncols` float keys (one row)
//   dst    - device pointer receiving `k` int column indices; the execution requirements
//            below request unsorted output, so the order of the indices is unspecified
//   stream - stream on which the kernel, the copy and the CUB calls are enqueued
//
// Failures reported by the kernel launch or by CUB go through CUDA_CHECK instead of being dropped.
static void top_k_cub(ggml_cuda_pool & pool,
                      const float *    src,
                      int *            dst,
                      const int        ncols,
                      const int        k,
                      cudaStream_t     stream) {
    // Nothing to select; this also avoids launching init_indices with an empty grid
    if (ncols <= 0 || k <= 0) {
        return;
    }

    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                                 cuda::execution::output_ordering::unsorted);
    auto stream_env   = cuda::stream_ref{ stream };
    auto env          = cuda::std::execution::env{ stream_env, requirements };

    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols);

    int *   temp_indices = temp_indices_alloc.get();
    float * temp_keys    = temp_keys_alloc.get();

    // Values paired with the keys: the column index of every element
    static const int block_size = 256;
    const dim3       grid_size((ncols + block_size - 1) / block_size, 1);
    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols);
    CUDA_CHECK(cudaGetLastError());

    // Work on a private copy of the keys because temp_keys is also passed as the key output below
    CUDA_CHECK(cudaMemcpyAsync(temp_keys, src, ncols * sizeof(float), cudaMemcpyDeviceToDevice, stream));

    // First call with a null workspace only queries the required temporary storage size
    // NOTE(review): keys_in and keys_out are the same buffer - confirm DeviceTopK permits in-place keys
    size_t temp_storage_bytes = 0;
    CUDA_CHECK(
        DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env));

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void *                        d_temp_storage = temp_storage_alloc.get();

    // Second call performs the selection; the selected indices land in dst
    CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
                                    ncols, k, env));
}
|
||||
|
||||
#else
|
||||
|
||||
// Returns the smallest power of two that is >= x; any x <= 1 yields 1.
static int next_power_of_2(int x) {
    int pow2 = 1;
    // keep doubling until the candidate reaches x
    for (; pow2 < x; pow2 += pow2) {
    }
    return pow2;
}
|
||||
|
||||
#endif // CUB_TOP_K_AVAILABLE
|
||||
|
||||
// CUDA implementation of the top-k op: for every row of dst->src[0] (F32, contiguous) writes
// dst->ne[0] I32 column indices into dst.
//
// With CUB top-k available each row is handled by top_k_cub. Otherwise the rows are fully
// argsorted in descending order into a scratch buffer and the first k indices of each row are
// copied out. Everything is enqueued on ctx.stream().
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    const float *       src0_d = (const float *) src0->data;
    int *               dst_d  = (int *) dst->data;
    cudaStream_t        stream = ctx.stream();

    // The raw pointers above are reinterpreted as float / int and rows are addressed with a
    // stride of ncols elements, so the types and the contiguity of src0 must hold
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t    ncols = src0->ne[0];
    const int64_t    nrows = ggml_nrows(src0);
    const int64_t    k     = dst->ne[0];
    ggml_cuda_pool & pool  = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
    // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
    // https://github.com/NVIDIA/cccl/issues/6391
    // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
    // 64-bit counter: matches the type of nrows and keeps the row offsets in 64-bit arithmetic
    for (int64_t i = 0; i < nrows; i++) {
        top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
    }
#else
    // Fall back to argsort + copy
    const int    ncols_pad      = next_power_of_2(ncols);
    const size_t shared_mem     = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
    // The bitonic kernel is only used when the padded row fits in shared memory and ncols <= 1024
    const bool   bitonic_ok     = shared_mem <= max_shared_mem && ncols <= 1024;

    ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
    int *                     tmp_dst = temp_dst_alloc.get();

#    ifdef GGML_CUDA_USE_CUB
    if (!bitonic_ok) {
        argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    } else {
        argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
    }
#    else
    // argsort_f32_i32_cuda_cub is only declared when GGML_CUDA_USE_CUB is defined, so it must not
    // be referenced here; rows too large for the bitonic kernel cannot be handled by this build
    GGML_ASSERT(bitonic_ok && "top-k: row too large for bitonic argsort and CUB is not available");
    argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
#    endif  // GGML_CUDA_USE_CUB

    // Strided copy: keep the first k indices of every sorted row (source pitch ncols, dest pitch k)
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#endif  // CUB_TOP_K_AVAILABLE
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -8035,6 +8035,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 40));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 1, 1, 1}, 1));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 400));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 40));
|
||||
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {200000, 1, 1, 1}, 1));
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue