cuda: fix race condition in cumsum (#18448)
* ggml-cuda: fix race condition in cumsum * remove unneccesary sync_threads
This commit is contained in:
parent
382808c14b
commit
5fa66c6e67
|
|
@ -69,11 +69,12 @@ static __global__ void cumsum_cub_kernel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// Update carry for next tile
|
// Update carry for next tile
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
block_carry += block_total;
|
block_carry += block_total;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
@ -175,11 +176,12 @@ static __global__ void cumsum_kernel(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// Update carry for next chunk
|
// Update carry for next chunk
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
*s_carry += *s_chunk_total;
|
*s_carry += *s_chunk_total;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue