ggml : optimize cuda cumsum fallback kernel (#18343)
This commit is contained in:
parent
f5acfb2ffa
commit
c54bba869d
|
|
@ -69,7 +69,7 @@ static __global__ void cumsum_cub_kernel(
|
|||
#endif // GGML_CUDA_USE_CUB
|
||||
}
|
||||
|
||||
// Fallback kernel implementation (original)
|
||||
// Fallback kernel implementation
|
||||
template<typename T>
|
||||
static __global__ void cumsum_kernel(
|
||||
const T * src, T * dst,
|
||||
|
|
@ -86,10 +86,10 @@ static __global__ void cumsum_kernel(
|
|||
const int warps_per_block = blockDim.x / warp_size;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + blockDim.x;
|
||||
float * s_carry = smem + blockDim.x + warps_per_block;
|
||||
float * s_chunk_total = s_carry + 1;
|
||||
float * s_vals = smem;
|
||||
float * s_warp_sums = smem + blockDim.x;
|
||||
float * s_carry = smem + blockDim.x + warps_per_block;
|
||||
float * s_chunk_total = s_carry + 1;
|
||||
|
||||
// Initialize carry
|
||||
if (tid == 0) {
|
||||
|
|
@ -107,21 +107,39 @@ static __global__ void cumsum_kernel(
|
|||
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
|
||||
for (int64_t start = 0; start < ne00; start += blockDim.x) {
|
||||
int64_t idx = start + tid;
|
||||
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
|
||||
// register blocking: process 4 elements per thread to hide latency
|
||||
// and reduce synchronization overhead
|
||||
constexpr int num_unroll = 4;
|
||||
T temp[num_unroll];
|
||||
|
||||
// 1. Warp inclusive scan
|
||||
for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
|
||||
int64_t idx = i + tid * num_unroll;
|
||||
|
||||
// thread local sequential scan
|
||||
temp[0] = (idx < ne00 ? src_row[idx] : T(0));
|
||||
#pragma unroll
|
||||
for (int64_t j = 1; j < num_unroll; j++) {
|
||||
temp[j] = temp[j - 1];
|
||||
if (idx + j < ne00) {
|
||||
temp[j] += src_row[idx + j];
|
||||
} else {
|
||||
temp[j] += 0;
|
||||
}
|
||||
}
|
||||
|
||||
// last emenent is sum of all values assigned to thread
|
||||
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
|
||||
|
||||
// Warp inclusive scan
|
||||
val = warp_prefix_inclusive_sum<T, warp_size>(val);
|
||||
s_vals[tid] = val;
|
||||
|
||||
// Store warp total
|
||||
if (lane == warp_size - 1) {
|
||||
s_warp_sums[warp] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// 2. Exclusive scan of warp sums (warp 0 only)
|
||||
// Exclusive scan of warp sums (warp 0 only)
|
||||
if (warp == 0) {
|
||||
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
|
||||
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
|
||||
|
|
@ -134,12 +152,17 @@ static __global__ void cumsum_kernel(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
// write back results
|
||||
float carry = *s_carry;
|
||||
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
|
||||
if (idx < ne00) {
|
||||
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
|
||||
// calculate sum offset for this thread
|
||||
float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < num_unroll; j++) {
|
||||
if (idx + j < ne00) {
|
||||
dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update carry for next chunk
|
||||
if (tid == 0) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue