From c54bba869da97feeaf788cac499679c9a9ff9fea Mon Sep 17 00:00:00 2001
From: Aadeshveer Singh <24b0926@iitb.ac.in>
Date: Thu, 25 Dec 2025 09:41:13 +0530
Subject: [PATCH] ggml : optimize cuda cumsum fallback kernel (#18343)

---
 ggml/src/ggml-cuda/cumsum.cu | 51 ++++++++++++++++++++++++++----------
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
index d2f2def8bd..0f72e33bba 100644
--- a/ggml/src/ggml-cuda/cumsum.cu
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -69,7 +69,7 @@ static __global__ void cumsum_cub_kernel(
 #endif // GGML_CUDA_USE_CUB
 }
 
-// Fallback kernel implementation (original)
+// Fallback kernel implementation
 template <typename T>
 static __global__ void cumsum_kernel(
         const T * src, T * dst,
@@ -86,10 +86,10 @@ static __global__ void cumsum_kernel(
     const int warps_per_block = blockDim.x / warp_size;
 
     extern __shared__ float smem[];
-    float * s_vals = smem;
-    float * s_warp_sums = smem + blockDim.x;
-    float * s_carry = smem + blockDim.x + warps_per_block;
-    float * s_chunk_total = s_carry + 1;
+    float * s_vals        = smem;
+    float * s_warp_sums   = smem + blockDim.x;
+    float * s_carry       = smem + blockDim.x + warps_per_block;
+    float * s_chunk_total = s_carry + 1;
 
     // Initialize carry
     if (tid == 0) {
@@ -107,21 +107,37 @@ static __global__ void cumsum_kernel(
     const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
     T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
 
-    for (int64_t start = 0; start < ne00; start += blockDim.x) {
-        int64_t idx = start + tid;
-        float val = (idx < ne00) ? ggml_cuda_cast<float>(src_row[idx]) : 0.0f;
+    // Register blocking: process 4 elements per thread to hide latency
+    // and reduce synchronization overhead
+    constexpr int num_unroll = 4;
+    T temp[num_unroll];
 
-        // 1. Warp inclusive scan
+    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+        int64_t idx = i + tid * num_unroll;
+
+        // Thread-local sequential inclusive scan
+        temp[0] = (idx < ne00) ? src_row[idx] : T(0);
+#pragma unroll
+        for (int64_t j = 1; j < num_unroll; j++) {
+            temp[j] = temp[j - 1];
+            if (idx + j < ne00) {
+                temp[j] += src_row[idx + j];
+            }
+        }
+
+        // The last element is the sum of all values assigned to this thread
+        float val = (idx < ne00) ? ggml_cuda_cast<float>(temp[num_unroll - 1]) : 0.0f;
+
+        // Warp inclusive scan
         val = warp_prefix_inclusive_sum(val);
         s_vals[tid] = val;
 
-        // Store warp total
         if (lane == warp_size - 1) {
             s_warp_sums[warp] = val;
         }
         __syncthreads();
 
-        // 2. Exclusive scan of warp sums (warp 0 only)
+        // Exclusive scan of warp sums (warp 0 only)
         if (warp == 0) {
             float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
             float inc = warp_prefix_inclusive_sum(w);
@@ -134,12 +150,19 @@ static __global__ void cumsum_kernel(
         }
         __syncthreads();
 
+        // Write back results
         float carry = *s_carry;
-        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
-        if (idx < ne00) {
-            dst_row[idx] = ggml_cuda_cast<T>(final_val);
+        // Exclusive prefix for this thread: the sum of everything before its chunk
+        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
+
+#pragma unroll
+        for (int32_t j = 0; j < num_unroll; j++) {
+            if (idx + j < ne00) {
+                dst_row[idx + j] = temp[j] + ggml_cuda_cast<T>(final_val_offset);
+            }
         }
+        // All threads must read *s_carry above before thread 0 updates it below
         __syncthreads();
 
         // Update carry for next chunk
         if (tid == 0) {
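-- 

Not part of the patch: below is a minimal standalone CUDA sketch of the register-blocking scheme the patch introduces, for readers who want to run the pattern in isolation. It makes simplifying assumptions: float only, a single block processing one row, WARP_SIZE hard-coded to 32, and the exclusive scan of warp totals done serially by thread 0 instead of the kernel's warp-level scan. The patch's warp_prefix_inclusive_sum helper is re-implemented here with __shfl_up_sync, and cumsum_row plus the main() harness are hypothetical names used for illustration only.

#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

constexpr int WARP_SIZE  = 32;
constexpr int BLOCK_SIZE = 256;
constexpr int NUM_UNROLL = 4;

// Warp-wide inclusive scan (Kogge-Stone, shuffle-based), standing in for
// the helper the patch calls.
static __device__ float warp_prefix_inclusive_sum(float v) {
    const int lane = (int) (threadIdx.x % WARP_SIZE);
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
        const float up = __shfl_up_sync(0xffffffff, v, offset);
        if (lane >= offset) {
            v += up;
        }
    }
    return v;
}

// Hypothetical single-row cumsum kernel using the same register-blocking idea.
static __global__ void cumsum_row(const float * src, float * dst, const int64_t n) {
    constexpr int warps_per_block = BLOCK_SIZE / WARP_SIZE;
    const int tid  = (int) threadIdx.x;
    const int lane = tid % WARP_SIZE;
    const int warp = tid / WARP_SIZE;

    __shared__ float s_warp_pre[warps_per_block]; // per-warp exclusive prefix (incl. carry)
    __shared__ float s_carry;                     // total of all previous chunks

    if (tid == 0) {
        s_carry = 0.0f;
    }
    __syncthreads();

    for (int64_t i = 0; i < n; i += (int64_t) NUM_UNROLL * BLOCK_SIZE) {
        const int64_t idx = i + (int64_t) tid * NUM_UNROLL;

        // 1. Thread-local sequential inclusive scan in registers
        float temp[NUM_UNROLL];
        temp[0] = idx < n ? src[idx] : 0.0f;
#pragma unroll
        for (int j = 1; j < NUM_UNROLL; j++) {
            temp[j] = temp[j - 1] + (idx + j < n ? src[idx + j] : 0.0f);
        }

        // 2. Warp-level inclusive scan of the per-thread totals
        const float total = temp[NUM_UNROLL - 1];
        const float inc   = warp_prefix_inclusive_sum(total);
        if (lane == WARP_SIZE - 1) {
            s_warp_pre[warp] = inc; // warp total; rewritten as a prefix below
        }
        __syncthreads();

        // 3. Serial exclusive scan of warp totals, folding in the carry
        if (tid == 0) {
            float run = s_carry;
            for (int w = 0; w < warps_per_block; w++) {
                const float t = s_warp_pre[w];
                s_warp_pre[w] = run;
                run += t;
            }
            s_carry = run; // carry for the next chunk
        }
        __syncthreads();

        // 4. Offset = sum of all elements before this thread's chunk
        const float offset = s_warp_pre[warp] + (inc - total);
#pragma unroll
        for (int j = 0; j < NUM_UNROLL; j++) {
            if (idx + j < n) {
                dst[idx + j] = temp[j] + offset;
            }
        }
        // Shared prefixes must be fully consumed before the next chunk rewrites them
        __syncthreads();
    }
}

int main() {
    const int64_t n = 10000;
    float * src, * dst;
    cudaMallocManaged(&src, n * sizeof(float));
    cudaMallocManaged(&dst, n * sizeof(float));
    for (int64_t i = 0; i < n; i++) {
        src[i] = (float) (i % 7) - 3.0f; // small integers keep float prefixes exact
    }
    cumsum_row<<<1, BLOCK_SIZE>>>(src, dst, n);
    cudaDeviceSynchronize();
    double ref = 0.0;
    for (int64_t i = 0; i < n; i++) {
        ref += src[i];
        if (fabs(dst[i] - ref) > 1e-3) {
            printf("mismatch at %lld: %f vs %f\n", (long long) i, (double) dst[i], ref);
            return 1;
        }
    }
    printf("cumsum OK (n=%lld)\n", (long long) n);
    return 0;
}

This should build with any recent CUDA toolkit (nvcc -o cumsum_sketch cumsum_sketch.cu) and prints cumsum OK on success. The point to notice is that each __syncthreads() now amortizes over NUM_UNROLL elements per thread instead of one, which is where the patch's reduction in synchronization overhead comes from; the barrier at the end of the chunk loop plays the same role as the one kept before the carry update above, ensuring the shared prefixes are fully consumed before they are rewritten.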