cuda: fix race condition in cumsum (#18448)

* ggml-cuda: fix race condition in cumsum

* remove unnecessary sync_threads
This commit is contained in:
Aman Gupta 2025-12-29 14:07:17 +08:00 committed by GitHub
parent 382808c14b
commit 5fa66c6e67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 5 additions and 3 deletions

View File

@ -61,7 +61,7 @@ static __global__ void cumsum_cub_kernel(
// Add offset to each item and store // Add offset to each item and store
T thread_offset = thread_prefix - thread_sum + block_carry; T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll #pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) { for (int i = 0; i < UNROLL_FACTOR; i++) {
int64_t idx = start + tid * UNROLL_FACTOR + i; int64_t idx = start + tid * UNROLL_FACTOR + i;
if (idx < ne00) { if (idx < ne00) {
@ -69,11 +69,12 @@ static __global__ void cumsum_cub_kernel(
} }
} }
__syncthreads();
// Update carry for next tile // Update carry for next tile
if (tid == 0) { if (tid == 0) {
block_carry += block_total; block_carry += block_total;
} }
__syncthreads();
} }
#else #else
NO_DEVICE_CODE; NO_DEVICE_CODE;
@ -175,11 +176,12 @@ static __global__ void cumsum_kernel(
} }
} }
__syncthreads();
// Update carry for next chunk // Update carry for next chunk
if (tid == 0) { if (tid == 0) {
*s_carry += *s_chunk_total; *s_carry += *s_chunk_total;
} }
__syncthreads();
} }
} }