#include "argsort.cuh"
|
|
#include "top-k.cuh"
|
|
|
|
#ifdef GGML_CUDA_USE_CUB
|
|
# include <cub/cub.cuh>
|
|
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
|
|
# define CUB_TOP_K_AVAILABLE
|
|
using namespace cub;
|
|
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
|
|
#endif // GGML_CUDA_USE_CUB
|
|
|
|
#ifdef CUB_TOP_K_AVAILABLE
|
|
|
|
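
// Find the indices of the k largest values in a single row of length ncols.
// Only the indices are written to dst; the values themselves are discarded.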
static void top_k_cub(ggml_cuda_pool & pool,
|
|
const float * src,
|
|
int * dst,
|
|
const int ncols,
|
|
const int k,
|
|
cudaStream_t stream) {
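    // Build a CUB execution environment: run on the given stream, and relax
    // both the determinism and the output-ordering requirements so CUB is
    // free to pick the fastest implementation. The k result indices
    // therefore come back in unspecified order.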
    auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
                                                 cuda::execution::output_ordering::unsorted);
    auto stream_env   = cuda::stream_ref{ stream };
    auto env          = cuda::std::execution::env{ stream_env, requirements };

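    // The candidate pairs are (value, column index); the index sequence
    // 0, 1, 2, ... is generated on the fly instead of being materialized.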
    auto indexes_in = cuda::make_counting_iterator(0);

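    // Standard two-pass CUB pattern: the first call (null temp storage) only
    // computes the required scratch size, the second call does the work.
    // The top-k values are not needed, only their indices, so the value
    // output is routed into a discard_iterator.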
    size_t temp_storage_bytes = 0;
    DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
                         env);

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
                         ncols, k, env);
}

#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE

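// Round x up to the next power of two, e.g. next_power_of_2(1000) == 1024.
// Used below to estimate the shared memory needed by the bitonic argsort,
// which operates on power-of-two padded rows.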
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

#endif // CUB_TOP_K_AVAILABLE

void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    const float *       src0_d = (const float *) src0->data;
    int *               dst_d  = (int *) dst->data;
    cudaStream_t        stream = ctx.stream();

    // All paths below assume F32 input, I32 output, and contiguous rows:
    // the raw-pointer casts and the per-row stride arithmetic depend on it.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

const int64_t ncols = src0->ne[0];
|
|
const int64_t nrows = ggml_nrows(src0);
|
|
const int64_t k = dst->ne[0];
|
|
ggml_cuda_pool & pool = ctx.pool();
|
|
#ifdef CUB_TOP_K_AVAILABLE
    // TODO: switch to `DeviceSegmentedTopK` for multi-row top-k once it is implemented:
    //       https://github.com/NVIDIA/cccl/issues/6391
    // TODO: investigate whether there is a problem size at which the parallelized
    //       argsort fallback becomes faster than this sequential per-row top-k
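    // One DeviceTopK launch per row; since all launches share a stream they
    // execute back to back.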
for (int i = 0; i < nrows; i++) {
|
|
top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
|
|
}
|
|
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
    // Fall back to a full argsort of every row, then copy out the first k
    // indices per row.
    const int    ncols_pad      = next_power_of_2(ncols);
    const size_t shared_mem     = ncols_pad * sizeof(int);
    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;

ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
|
|
int * tmp_dst = temp_dst_alloc.get();
|
|
|
|
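    // Use the bitonic argsort only when a padded row of indices fits in the
    // per-block shared memory budget and the row is at most 1024 columns
    // wide; otherwise take the CUB-based argsort.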
if (shared_mem > max_shared_mem || ncols > 1024) {
|
|
argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
|
|
} else {
|
|
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
|
|
}
|
|
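    // Strided 2D copy: of the ncols sorted indices per row, keep only the
    // first k (source pitch ncols * sizeof(int), destination pitch
    // k * sizeof(int)).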
    CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
                                 cudaMemcpyDeviceToDevice, stream));
#else // GGML_CUDA_USE_CUB
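    // Without CUB the bitonic argsort is the only available implementation;
    // as above, the first k indices of each sorted row are copied out.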
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
|
|
int * tmp_dst = temp_dst_alloc.get();
|
|
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
|
|
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
|
|
cudaMemcpyDeviceToDevice, stream));
|
|
#endif
|
|
}
|