From 3732b85b09ef5f821e92d6a370a74f554abb7945 Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Mon, 15 Dec 2025 11:01:12 +0100
Subject: [PATCH] Fix data-race in `soft_max_f32_parallelize_cols_single_row`

Using `tmp_vals` to store both the max values and the exponential
accumulators created a potential data race: the exponential accumulator
of a given CTA could be written to `tmp_vals` before all other CTAs had
read the max value from it. To avoid a third g.sync(), an additional
temporary buffer was added instead. Given the g.sync() calls that
already follow each write to gmem, it is now guaranteed that the
previous maxs/sums have been read by all CTAs before they can be
overwritten.
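To make the hazard and the fix concrete, below is a minimal,
self-contained sketch of the resulting two-buffer reduction pattern.
It is not the ggml kernel: the kernel name, the one-warp-per-CTA
layout, and the serial reduction over the per-CTA partials are
simplified stand-ins, and a cooperative launch (via
cudaLaunchCooperativeKernel) is assumed so that g.sync() is valid.

    #include <cooperative_groups.h>
    #include <cfloat>

    namespace cg = cooperative_groups;

    // Two-pass grid-wide softmax over x[0..n): pass 1 finds the global
    // max, pass 2 accumulates the sum of exponentials. Each pass
    // publishes one partial per CTA to its own scratch buffer; reusing
    // tmp_maxs for the sums would let a fast CTA overwrite an entry
    // that a slow CTA is still reading after the first g.sync().
    __global__ void softmax_two_buffer_sketch(const float * x, float * dst,
                                              float * tmp_maxs, float * tmp_sums,
                                              int n) {
        cg::grid_group g = cg::this_grid();
        const int stride = gridDim.x * blockDim.x;
        const int start  = blockIdx.x * blockDim.x + threadIdx.x;

        // Pass 1: CTA-level max via warp shuffles (blockDim.x == 32).
        float m = -FLT_MAX;
        for (int i = start; i < n; i += stride) m = fmaxf(m, x[i]);
        for (int off = 16; off > 0; off >>= 1)
            m = fmaxf(m, __shfl_xor_sync(0xffffffff, m, off));
        if (threadIdx.x == 0) tmp_maxs[blockIdx.x] = m;
        g.sync();                          // all CTA-level maxs visible
        m = -FLT_MAX;
        for (int b = 0; b < (int) gridDim.x; ++b) m = fmaxf(m, tmp_maxs[b]);

        // Pass 2: CTA-level sum of exp(x - max), written to the second buffer.
        float s = 0.0f;
        for (int i = start; i < n; i += stride) s += expf(x[i] - m);
        for (int off = 16; off > 0; off >>= 1)
            s += __shfl_xor_sync(0xffffffff, s, off);
        if (threadIdx.x == 0) tmp_sums[blockIdx.x] = s;
        g.sync();                          // all CTA-level sums visible
        s = 0.0f;
        for (int b = 0; b < (int) gridDim.x; ++b) s += tmp_sums[b];

        for (int i = start; i < n; i += stride) dst[i] = expf(x[i] - m) / s;
    }

With this split, pass 2 writes to memory disjoint from what pass 1's
readers touch, so the two existing g.sync() calls suffice and no third
one is needed.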
---
 ggml/src/ggml-cuda/softmax.cu | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index 4dffb1c168..4773d93f7e 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -216,7 +216,8 @@ static __device__ float two_stage_warp_reduce_sum(float val) {
 
 // TODO: Template to allow keeping ncols in registers if they fit
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst,
-                                                                float * __restrict__ tmp_vals,
+                                                                float * __restrict__ tmp_maxs,
+                                                                float * __restrict__ tmp_sums,
                                                                 const soft_max_params p) {
     namespace cg = cooperative_groups;
 
@@ -249,14 +250,14 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst,
 
     // Store CTA-level max to GMEM
     if (tid == 0) {
-        tmp_vals[blockIdx.x] = local_max;
+        tmp_maxs[blockIdx.x] = local_max;
     }
     g.sync();
 
     // Compute global max from CTA-level maxs
     assert(gridDim.x < blockDim.x); // currently we only support this case
     if (tid < gridDim.x) {
-        local_max = tmp_vals[tid];
+        local_max = tmp_maxs[tid];
     } else {
         local_max = -INFINITY;
     }
@@ -287,13 +288,13 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst,
 
     // Store CTA-level sum to GMEM
     if (tid == 0) {
-        tmp_vals[blockIdx.x] = tmp_expf;
+        tmp_sums[blockIdx.x] = tmp_expf;
     }
     g.sync();
 
     // Compute global sum from CTA-level sums
     if (tid < gridDim.x) {
-        tmp_expf = tmp_vals[tid];
+        tmp_expf = tmp_sums[tid];
     } else {
         tmp_expf = 0.0f;
     }
@@ -375,7 +376,8 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
 }
 
 __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, float * __restrict__ dst,
-                                                                                       float * __restrict__ tmp_vals,
+                                                                                       float * __restrict__ tmp_maxs,
+                                                                                       float * __restrict__ tmp_sums,
                                                                                        const soft_max_params p)
 // We loop over all instead of parallelizing across gridDim.y as cooperative groups
 // currently only support synchronizing the complete grid if not launched as a cluster group
@@ -384,8 +386,8 @@ __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, float * __restrict__ dst,
 // https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
 {
     for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
-        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_vals,
-                                                 p);
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
     }
 }
 
@@ -418,12 +420,14 @@ static void soft_max_f32_cuda(const float * x,
     // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
     // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
     if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
-        ggml_cuda_pool_alloc<float> tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+        ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+        ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
 
-        void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
-        CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
-                                               dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
-                                               dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+        void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+                                 (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
+        CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                               dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                               dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
     } else {
         const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
         soft_max_f32