cuda: fix race condition in cumsum (#18448)
* ggml-cuda: fix race condition in cumsum * remove unneccesary sync_threads
This commit is contained in:
parent
382808c14b
commit
5fa66c6e67
|
|
@ -61,7 +61,7 @@ static __global__ void cumsum_cub_kernel(
|
||||||
|
|
||||||
// Add offset to each item and store
|
// Add offset to each item and store
|
||||||
T thread_offset = thread_prefix - thread_sum + block_carry;
|
T thread_offset = thread_prefix - thread_sum + block_carry;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < UNROLL_FACTOR; i++) {
|
for (int i = 0; i < UNROLL_FACTOR; i++) {
|
||||||
int64_t idx = start + tid * UNROLL_FACTOR + i;
|
int64_t idx = start + tid * UNROLL_FACTOR + i;
|
||||||
if (idx < ne00) {
|
if (idx < ne00) {
|
||||||
|
|
@ -69,11 +69,12 @@ static __global__ void cumsum_cub_kernel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// Update carry for next tile
|
// Update carry for next tile
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
block_carry += block_total;
|
block_carry += block_total;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
@ -175,11 +176,12 @@ static __global__ void cumsum_kernel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// Update carry for next chunk
|
// Update carry for next chunk
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
*s_carry += *s_chunk_total;
|
*s_carry += *s_chunk_total;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue