Fix data-race in `soft_max_f32_parallelize_cols_single_row`
Using `tmp_vals` to store both the max values and the exponential accumulator created a potential data race: the exponential accumulator of a given CTA could be written to `tmp_vals` before all other CTAs had read the max value from it. To avoid a third `g.sync()`, an additional temporary buffer was added instead. Since a sync is already in place after each write to GMEM, it is now guaranteed that the previous max/sum values have been read by all CTAs before they are overwritten.
parent 2652e745ef
commit 3732b85b09
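To make the ordering hazard concrete, here is a minimal, self-contained CUDA sketch of the same grid-wide softmax pattern. It is not the ggml kernel: the kernel name `softmax_row_sketch`, the fixed block size of 256, and the plain shared-memory reductions are illustrative assumptions; only the split into separate `tmp_maxs` and `tmp_sums` scratch buffers and the two `g.sync()` points mirror this patch.

```cpp
// Minimal sketch of a grid-wide, single-row softmax (NOT the ggml kernel).
// Assumes a cooperative launch, blockDim.x == 256, and tmp_maxs/tmp_sums each
// holding gridDim.x floats. Buffer names mirror the patch; the rest is illustrative.
#include <cooperative_groups.h>
#include <math.h>

namespace cg = cooperative_groups;

__global__ void softmax_row_sketch(const float * x, float * dst,
                                   float * tmp_maxs, float * tmp_sums, int ncols) {
    cg::grid_group g = cg::this_grid();
    const int tid    = threadIdx.x;
    const int gtid   = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = gridDim.x * blockDim.x;

    __shared__ float smem[256];   // must match blockDim.x for the reductions below

    // CTA-level max over this block's strided slice of the row.
    float v = -INFINITY;
    for (int i = gtid; i < ncols; i += stride) {
        v = fmaxf(v, x[i]);
    }
    smem[tid] = v;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] = fmaxf(smem[tid], smem[tid + s]);
        }
        __syncthreads();
    }
    if (tid == 0) {
        tmp_maxs[blockIdx.x] = smem[0];   // each CTA writes only its own slot
    }
    g.sync();                             // all CTA maxes are now visible grid-wide

    // Global max: every thread scans the (small) per-CTA array.
    float row_max = -INFINITY;
    for (int b = 0; b < (int) gridDim.x; ++b) {
        row_max = fmaxf(row_max, tmp_maxs[b]);
    }

    // CTA-level sum of exp(x - row_max). Writing this into tmp_maxs instead of a
    // separate tmp_sums is the hazard the commit removes: without one more
    // g.sync() before the store, a fast CTA could overwrite a slot that a slow
    // CTA has not yet read in the global-max loop above.
    float acc = 0.0f;
    for (int i = gtid; i < ncols; i += stride) {
        acc += expf(x[i] - row_max);
    }
    smem[tid] = acc;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }
    if (tid == 0) {
        tmp_sums[blockIdx.x] = smem[0];
    }
    g.sync();                             // all CTA sums are now visible grid-wide

    float row_sum = 0.0f;
    for (int b = 0; b < (int) gridDim.x; ++b) {
        row_sum += tmp_sums[b];
    }

    // Normalize the row.
    for (int i = gtid; i < ncols; i += stride) {
        dst[i] = expf(x[i] - row_max) / row_sum;
    }
}
```

The kernel must be launched with `cudaLaunchCooperativeKernel` (as the real code does) for `g.sync()` to be valid, and each scratch buffer needs one float per CTA. If `tmp_sums` were folded back into `tmp_maxs`, a third grid-wide sync would be required between the global-max read and the per-CTA sum store, which is exactly the trade-off described in the commit message.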
@@ -216,7 +216,8 @@ static __device__ float two_stage_warp_reduce_sum(float val) {
 // TODO: Template to allow keeping ncols in registers if they fit
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                 float * __restrict__ dst,
-                                                                float * __restrict__ tmp_vals,
+                                                                float * __restrict__ tmp_maxs,
+                                                                float * __restrict__ tmp_sums,
                                                                 const soft_max_params p) {
     namespace cg = cooperative_groups;
 
@@ -249,14 +250,14 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
 
     // Store CTA-level max to GMEM
     if (tid == 0) {
-        tmp_vals[blockIdx.x] = local_max;
+        tmp_maxs[blockIdx.x] = local_max;
     }
     g.sync();
 
     // Compute global max from CTA-level maxs
     assert(gridDim.x < blockDim.x); // currently we only support this case
     if (tid < gridDim.x) {
-        local_max = tmp_vals[tid];
+        local_max = tmp_maxs[tid];
     } else {
         local_max = -INFINITY;
     }
@@ -287,13 +288,13 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
 
     // Store CTA-level sum to GMEM
     if (tid == 0) {
-        tmp_vals[blockIdx.x] = tmp_expf;
+        tmp_sums[blockIdx.x] = tmp_expf;
     }
     g.sync();
 
     // Compute global sum from CTA-level sums
     if (tid < gridDim.x) {
-        tmp_expf = tmp_vals[tid];
+        tmp_expf = tmp_sums[tid];
     } else {
         tmp_expf = 0.0f;
     }
@@ -375,7 +376,8 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
 
 __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
                                                                                        float * __restrict__ dst,
-                                                                                       float * __restrict__ tmp_vals,
+                                                                                       float * __restrict__ tmp_maxs,
+                                                                                       float * __restrict__ tmp_sums,
                                                                                        const soft_max_params p)
 // We loop over all instead of parallelizing across gridDim.y as cooperative groups
 // currently only support synchronizing the complete grid if not launched as a cluster group
@@ -384,8 +386,8 @@ __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_paralleliz
 // https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
 {
     for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
-        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_vals,
-                                                 p);
+        soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+                                                 tmp_sums, p);
     }
 }
 
@@ -418,12 +420,14 @@ static void soft_max_f32_cuda(const float * x,
     // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
     // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
     if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
-        ggml_cuda_pool_alloc<float> tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+        ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+        ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
 
-        void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
-        CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
-                                               dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
-                                               dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+        void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+                                 (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
+        CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+                                               dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+                                               dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
     } else {
         const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
         soft_max_f32<false, 0, 0>