CUDA: Add Cooperative-Groups-based parallelization of ncols in softmax

Old implementation parallelizes rows across SMs, which does not fit the
needs of backend-sampling (where we have ncols >> nrows and thus want to
parallelize ncols across SMs)
This commit is contained in:
Oliver Simons 2025-12-08 16:48:52 +01:00
parent 7ab6f51b97
commit a84dfd3e10
4 changed files with 202 additions and 15 deletions

View File

@ -912,15 +912,16 @@ struct ggml_cuda_device_info {
int device_count;
struct cuda_device_info {
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
int cc; // compute capability
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
int warp_size; // Number of threads in a dispatch
int warp_size; // Number of threads in a dispatch
bool supports_cooperative_launch; // whether cooperative launch is supported
};
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

View File

@ -242,6 +242,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
int supportsCoopLaunch = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&supportsCoopLaunch, cudaDevAttrCooperativeLaunch, id));
info.devices[id].supports_cooperative_launch = !!supportsCoopLaunch;
#if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock;

View File

@ -1,6 +1,10 @@
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdint>
#include <utility>
@ -160,6 +164,148 @@ static __global__ void soft_max_f32(
dst[col] = vals[col] * inv_sum;
}
}
// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
static __device__ float two_stage_warp_reduce_max(float val) {
val = warp_reduce_max(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = -INFINITY;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_max(val);
} else {
return val;
}
}
static __device__ float two_stage_warp_reduce_sum(float val) {
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_sum(val);
} else {
return val;
}
}
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_vals,
const soft_max_params p) {
namespace cg = cooperative_groups;
const cg::grid_group g = cg::this_grid();
const int tid = threadIdx.x;
const int col_start = blockIdx.x * blockDim.x + tid;
const int n_elem_per_thread = 4;
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
for (int i = 0; i < n_elem_per_thread; i++) {
local_max = fmaxf(local_max, local_vals[i]);
}
col += step_size * n_elem_per_thread;
}
// Compute CTA-level max
local_max = two_stage_warp_reduce_max(local_max);
// Store CTA-level max to GMEM
if (tid == 0) {
tmp_vals[blockIdx.x] = local_max;
}
g.sync();
// Compute compute global max from CTA-level maxs
assert(gridDim.x < blockDim.x); // currently we only support this case
if (tid < gridDim.x) {
local_max = tmp_vals[tid];
} else {
local_max = -INFINITY;
}
local_max = two_stage_warp_reduce_max(local_max);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
for (int col = col_start; col < p.ncols;) {
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
const float tmp = expf(local_vals[i] - local_max);
tmp_expf += tmp;
dst[idx] = tmp;
}
}
col += step_size * n_elem_per_thread;
}
// Reduce divisor within CTA
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
// Store CTA-level sum to GMEM
if (tid == 0) {
tmp_vals[blockIdx.x] = tmp_expf;
}
g.sync();
// Compute global sum from CTA-level sums
if (tid < gridDim.x) {
tmp_expf = tmp_vals[tid];
} else {
tmp_expf = 0.0f;
}
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
}
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
dst[idx] = local_vals[i] / tmp_expf;
}
}
col += step_size * n_elem_per_thread;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
@ -216,9 +362,30 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_vals,
const soft_max_params p)
// We loop over all instead of parallelizing across gridDim.y as cooperative groups
// currently only support synchronizing the complete grid if not launched as a cluster group
// (which requires CC > 9.0)
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
{
for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_vals,
p);
}
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
template <typename T>
static void soft_max_f32_cuda(const float * x,
const T * mask,
const float * sinks,
float * dst,
const soft_max_params & params,
cudaStream_t stream,
[[maybe_unused]] ggml_backend_cuda_context & ctx) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
@ -236,8 +403,21 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
// Parallelize across SMs for top-p/dist-smapling
if (ncols_x > 10000 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
if (ggml_cuda_info().devices[id].supports_cooperative_launch) {
ggml_cuda_pool_alloc<float> tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_alloc.ptr, (void *) &params };
CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
}
} else {
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
soft_max_f32<false, 0, 0>
<<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
}
@ -315,9 +495,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
}
}

View File

@ -7588,6 +7588,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
exponent <<= 1;
}
#endif
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
for (bool mask : {false, true}) {
for (bool sinks : {false, true}) {
for (float max_bias : {0.0f, 8.0f}) {
@ -7638,7 +7641,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
for (bool fw : {true, false}) { // fw == forward
bool all = true;