diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 90794ff264..eaaf87612d 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -530,6 +530,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+enum class block_reduce_method {
+    MAX,
+    SUM,
+};
+
+template <typename T, block_reduce_method method>
+struct block_reduce_policy;
+
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+template <typename T>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template <typename T> struct block_reduce_policy<T, block_reduce_method::SUM> {
+    static __device__ T reduce(T val) {
+        if constexpr (is_any<T, float, float2, half2, int>) {
+            return warp_reduce_sum(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return 0.0f;
+        } else if constexpr (std::is_same_v<T, float2>) {
+            return make_float2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, int>) {
+            return 0;
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+};
+
+template <typename T> struct block_reduce_policy<T, block_reduce_method::MAX> {
+    static __device__ T reduce(T val) {
+        if constexpr (is_any<T, float, half2>) {
+            return warp_reduce_max(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return -INFINITY;
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(-INFINITY, -INFINITY);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+};
+
+template <block_reduce_method method, int block_size_template = 0, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+    val = block_reduce_policy<T, method>::reduce(val);
+    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+    if (block_size > WARP_SIZE) {
+        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = block_reduce_policy<T, method>::sentinel();
+        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+            val = shared_vals[lane_id];
+        }
+        return block_reduce_policy<T, method>::reduce(val);
+    }
+
+    return val;
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c3ee2ea066..553623fbd4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4551,7 +4551,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            return ggml_is_contiguous(op->src[0]);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 4f153c5718..ef98f675aa 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -25,19 +25,8 @@ static __global__ void norm_f32(
     }
 
     // sum up partial sums
-    mean_var = warp_reduce_sum(mean_var);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float2 s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
-        }
-        __syncthreads();
-        mean_var = s_sum[lane_id];
-        mean_var = warp_reduce_sum(mean_var);
-    }
+    extern __shared__ float2 s_sum2[];
+    mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
 
     const float mean = mean_var.x / ncols;
     const float var = mean_var.y / ncols - mean * mean;
@@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += x[j];
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float mean = tmp / group_size;
     tmp = 0.0f;
@@ -84,18 +62,7 @@
         tmp += xi * xi;
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float variance = tmp / group_size;
     const float scale = rsqrtf(variance + eps);
@@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = tid / WARP_SIZE;
-        const int lane_id = tid % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = 0.0f;
-        if (lane_id < (block_size / WARP_SIZE)) {
-            tmp = s_sum[lane_id];
-        }
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
@@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
 
     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@@ -337,7 +279,7 @@ static void norm_f32_cuda(
         norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+        group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
     }
 }
 
@@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float * x,
     const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
+        rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
             x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
             mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
+        rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
             x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
             mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
     }
@@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float * x,
     const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
+        rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
             x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
             mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
             add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
             add_nchannels_packed, add_nsamples_packed);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
+        rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
             x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
             mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
             add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
         l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh
index 6bcae9e52f..de240fd441 100644
--- a/ggml/src/ggml-cuda/reduce_rows.cuh
+++ b/ggml/src/ggml-cuda/reduce_rows.cuh
@@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
     }
 
     // sum up partial sums
-    sum = warp_reduce_sum(sum);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = sum;
-        }
-        __syncthreads();
-        sum = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            sum = s_sum[lane_id];
-        }
-        sum = warp_reduce_sum(sum);
-    }
+    __shared__ float shared_vals[32];
+    sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
 
     if (col != 0) {
         return;
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index 1ae84ebf63..dc06d06930 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -75,9 +75,6 @@ static __global__ void soft_max_f32(
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
     const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
@@ -102,21 +99,7 @@
     }
 
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
-    }
+    max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
 
     float tmp = 0.0f; // partial sum
 
@@ -134,22 +117,7 @@
     }
 
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __syncthreads();
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf_iw[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
 
     if (sinks) {
         tmp += expf(sinks[i02] - max_val);
@@ -169,50 +137,6 @@
     }
 }
 
-
-// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
-static __device__ float two_stage_warp_reduce_max(float val) {
-    val = warp_reduce_max(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = -INFINITY;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_max(val);
-    } else {
-        return val;
-    }
-}
-
-static __device__ float two_stage_warp_reduce_sum(float val) {
-    val = warp_reduce_sum(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_sum(val);
-    } else {
-        return val;
-    }
-}
-
 // TODO: Template to allow keeping ncols in registers if they fit
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                 float * __restrict__ dst,
@@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
     float local_max = -INFINITY;
     const int step_size = gridDim.x * blockDim.x;
+    __shared__ float shared_vals[32];
 
     // Compute thread-local max
     for (int col = col_start; col < p.ncols;) {
@@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Compute CTA-level max
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
 
     // Store CTA-level max to GMEM
     if (tid == 0) {
@@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         local_max = -INFINITY;
     }
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
 
     // Compute softmax dividends, accumulate divisor
     float tmp_expf = 0.0f;
@@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Reduce divisor within CTA
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
 
     // Store CTA-level sum to GMEM
     if (tid == 0) {
@@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         tmp_expf = 0.0f;
     }
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
 
     // Divide dividend by global sum + store data
     for (int col = col_start; col < p.ncols;) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 19ef58404e..188ffdf3db 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7482,25 +7482,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
     test_cases.emplace_back(new test_silu_back());
 
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
-        for (bool v : {false, true}) {
-            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
-            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
+        for (uint32_t n : { 64, 1025 }) {
+            for (bool v : { false, true }) {
+                test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+                test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
+            }
+            test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
+            test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
         }
-        test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
     // in-place tests
     test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
 
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
-        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
-        test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
-        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
-        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+    for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
+        for (uint32_t n : { 64, 1025 }) {
+            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
+            test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
+        }
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
         for (bool multi_add : {false, true}) {
@@ -7524,9 +7528,6 @@
             }
         }
     }
-
-    test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
-
     for (int64_t d_conv : {3, 4, 9}) {
         for (int64_t d_inner: {1024, 1536, 2048}) {
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
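Not part of the patch: a minimal sketch of how a kernel is expected to call the new `block_reduce` helper. The kernel name and launch configuration below are illustrative only, and the sketch assumes the template-argument order reconstructed above (reduction method first, compile-time block size second and defaulting to 0, i.e. `blockDim.x`), mirroring the `reduce_rows.cuh` call site.

```cpp
// Illustrative usage of block_reduce (hypothetical kernel, not part of the diff):
// one CTA per row, each thread accumulates a partial sum, then the block reduces it.
static __global__ void row_sum_example(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += x[(int64_t) row*ncols + col];
    }

    // One shared slot per warp; block_reduce only touches it when blockDim.x > WARP_SIZE,
    // which is why the norm.cu launches above size dynamic shared memory conditionally.
    __shared__ float s_buf[32];
    sum = block_reduce<block_reduce_method::SUM>(sum, s_buf);

    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}

// hypothetical launch: row_sum_example<<<nrows, 256, 0, stream>>>(x, dst, ncols);
```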