From a84dfd3e1072bcf422b90dc4b03d334395c90fe4 Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Mon, 8 Dec 2025 16:48:52 +0100
Subject: [PATCH] CUDA: Add Cooperative-Groups-based parallelization of ncols
 in softmax

Old implementation parallelizes rows across SMs, which does not fit the
needs of backend-sampling (where we have ncols >> nrows and thus want to
parallelize ncols across SMs)
---
 ggml/src/ggml-cuda/common.cuh   |  17 +--
 ggml/src/ggml-cuda/ggml-cuda.cu |   4 +
 ggml/src/ggml-cuda/softmax.cu   | 192 +++++++++++++++++++++++++++++++-
 tests/test-backend-ops.cpp      |   4 +-
 4 files changed, 202 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c4529f5d94..97fb716dba 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -912,15 +912,16 @@ struct ggml_cuda_device_info {
     int device_count;
 
     struct cuda_device_info {
-        int     cc;                 // compute capability
-        int     nsm;                // number of streaming multiprocessors
-        size_t  smpb;               // max. shared memory per block
-        size_t  smpbo;              // max. shared memory per block (with opt-in)
-        bool    integrated;         // Device is integrated as opposed to discrete
-        bool    vmm;                // virtual memory support
-        size_t  vmm_granularity;    // granularity of virtual memory
+        int     cc;                          // compute capability
+        int     nsm;                         // number of streaming multiprocessors
+        size_t  smpb;                        // max. shared memory per block
+        size_t  smpbo;                       // max. shared memory per block (with opt-in)
+        bool    integrated;                  // Device is integrated as opposed to discrete
+        bool    vmm;                         // virtual memory support
+        size_t  vmm_granularity;             // granularity of virtual memory
         size_t  total_vram;
-        int     warp_size;          // Number of threads in a dispatch
+        int     warp_size;                   // Number of threads in a dispatch
+        bool    supports_cooperative_launch; // whether cooperative launch is supported
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 556db4d2b0..e18a1ce801 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -242,6 +242,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm       = prop.multiProcessorCount;
         info.devices[id].smpb      = prop.sharedMemPerBlock;
         info.devices[id].warp_size = prop.warpSize;
+
+        int supportsCoopLaunch = 0;
+        CUDA_CHECK(cudaDeviceGetAttribute(&supportsCoopLaunch, cudaDevAttrCooperativeLaunch, id));
+        info.devices[id].supports_cooperative_launch = !!supportsCoopLaunch;
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index eeacde0bdb..9b13614cfd 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -1,6 +1,10 @@
 #include "common.cuh"
 #include "ggml.h"
 #include "softmax.cuh"
+
+#include <cooperative_groups.h>
+#include <cassert>
+
 #include <cstdint>
 #include <utility>
@@ -160,6 +164,148 @@ static __global__ void soft_max_f32(
         dst[col] = vals[col] * inv_sum;
     }
 }
+
+
+// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
+static __device__ float two_stage_warp_reduce_max(float val) {
+    val = warp_reduce_max(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = -INFINITY;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_max(val);
+    } else {
+        return val;
+    }
+}
+
+static __device__ float two_stage_warp_reduce_sum(float val) {
+    val = warp_reduce_sum(val);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float local_vals[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            local_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = 0.0f;
+        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
+            val = local_vals[lane_id];
+        }
+        return warp_reduce_sum(val);
+    } else {
+        return val;
+    }
+}
+
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+                                                                float * __restrict__ dst,
+                                                                float * __restrict__ tmp_vals,
+                                                                const soft_max_params p) {
+    namespace cg = cooperative_groups;
+
+    const cg::grid_group g = cg::this_grid();
+
+    const int tid               = threadIdx.x;
+    const int col_start         = blockIdx.x * blockDim.x + tid;
+    const int n_elem_per_thread = 4;
+
+    float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+    float     local_max                     = -INFINITY;
+    const int step_size                     = gridDim.x * blockDim.x;
+
+    // Compute thread-local max
+    for (int col = col_start; col < p.ncols;) {
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            local_max = fmaxf(local_max, local_vals[i]);
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Compute CTA-level max
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Store CTA-level max to GMEM
+    if (tid == 0) {
+        tmp_vals[blockIdx.x] = local_max;
+    }
+    g.sync();
+
+    // Compute global max from CTA-level maxes
+    assert(gridDim.x < blockDim.x); // currently we only support this case
+    if (tid < gridDim.x) {
+        local_max = tmp_vals[tid];
+    } else {
+        local_max = -INFINITY;
+    }
+    local_max = two_stage_warp_reduce_max(local_max);
+
+    // Compute softmax dividends, accumulate divisor
+    float tmp_expf = 0.0f;
+    for (int col = col_start; col < p.ncols;) {
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+        }
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                const float tmp = expf(local_vals[i] - local_max);
+                tmp_expf += tmp;
+                dst[idx] = tmp;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+
+    // Reduce divisor within CTA
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Store CTA-level sum to GMEM
+    if (tid == 0) {
+        tmp_vals[blockIdx.x] = tmp_expf;
+    }
+    g.sync();
+
+    // Compute global sum from CTA-level sums
+    if (tid < gridDim.x) {
+        tmp_expf = tmp_vals[tid];
+    } else {
+        tmp_expf = 0.0f;
+    }
+    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+
+    // Divide dividends by global sum + store data
+    for (int col = col_start; col < p.ncols;) {
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+        }
+        for (int i = 0; i < n_elem_per_thread; i++) {
+            const int idx = col + i * step_size;
+            if (idx < p.ncols) {
+                dst[idx] = local_vals[i] / tmp_expf;
+            }
+        }
+        col += step_size * n_elem_per_thread;
+    }
+}
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif // __clang__
@@ -216,9 +362,30 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
     soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
+static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+                                                     float * __restrict__ dst,
+                                                     float * __restrict__ tmp_vals,
+                                                     const soft_max_params p)
+// We loop over all rows instead of parallelizing across gridDim.y as cooperative groups
+// currently only support synchronizing the complete grid if not launched as a cluster group
+// (which requires CC >= 9.0)
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+{
+    for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_vals,
+                                                 p);
+    }
+}
-template <typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
+template <typename T>
+static void soft_max_f32_cuda(const float * x,
+                              const T * mask,
+                              const float * sinks,
+                              float * dst,
+                              const soft_max_params & params,
+                              cudaStream_t stream,
+                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
@@ -236,8 +403,21 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
     if (nbytes_shared <= smpbo) {
         launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
-        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        // Parallelize across SMs for top-p/dist-sampling
+        if (ncols_x > 10000 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
+            if (ggml_cuda_info().devices[id].supports_cooperative_launch) {
+                ggml_cuda_pool_alloc<float> tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+
+                void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_alloc.ptr, (void *) &params };
+                CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                                       dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                                       dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+            }
+        } else {
+            const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+            soft_max_f32<false, 0, 0>
+                <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+        }
     }
 }
@@ -315,9 +495,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
    }
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f569f7a7f8..2777cd2b82 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7588,6 +7588,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         exponent <<= 1;
     }
 #endif
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     for (bool mask : {false, true}) {
         for (bool sinks : {false, true}) {
             for (float max_bias : {0.0f, 8.0f}) {
@@ -7638,7 +7641,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             }
         }
     }
-
     for (bool fw : {true, false}) { // fw == forward
         bool all = true;
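
Note: the new code path above hinges on CUDA cooperative launch, which is what makes the two g.sync() grid-wide barriers between the max, sum, and normalization passes legal. Below is a minimal, self-contained sketch of that launch/sync pattern, kept separate from the patch itself; the kernel name, sizes, and build line are illustrative assumptions, not ggml code.

// Build (assumption): nvcc -arch=native -rdc=true coop_sketch.cu
#include <cooperative_groups.h>
#include <cstdio>

namespace cg = cooperative_groups;

// Two-phase kernel: every block writes in phase 1, the grid-wide barrier makes
// those writes visible to the whole grid, then phase 2 reads them back. This is
// the same mechanism the patch uses between its max, sum, and normalize passes.
__global__ void two_phase_kernel(float * data, int n) {
    const cg::grid_group grid = cg::this_grid();
    const int step = gridDim.x * blockDim.x;

    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += step) {
        data[i] += 1.0f;   // phase 1
    }

    grid.sync();           // valid only because the kernel was launched cooperatively

    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += step) {
        data[i] *= 2.0f;   // phase 2, sees all phase-1 writes
    }
}

int main() {
    int device = 0, coop = 0, nsm = 0;
    cudaDeviceGetAttribute(&coop, cudaDevAttrCooperativeLaunch, device);
    if (!coop) {
        printf("cooperative launch not supported on this device\n");
        return 0;
    }
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device);

    const int n = 1 << 20;
    float * data = nullptr;
    cudaMalloc(&data, n * sizeof(float));
    cudaMemset(data, 0, n * sizeof(float));

    // The whole grid must be resident at once, so size it from occupancy rather than from n.
    const int block_size    = 256;
    int       blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, two_phase_kernel, block_size, 0);

    int    n_arg  = n;
    void * args[] = { (void *) &data, (void *) &n_arg };
    cudaLaunchCooperativeKernel((void *) two_phase_kernel, dim3(nsm * blocks_per_sm, 1, 1),
                                dim3(block_size, 1, 1), args, 0, 0);
    cudaDeviceSynchronize();

    cudaFree(data);
    return 0;
}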