diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index ef98f675aa..ec3b21f5ec 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -150,6 +150,93 @@ static __global__ void rms_norm_f32(const float * x,
     }
 }
 
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32_vec4(const float * x,
+                                         float * dst,
+                                         const int ncols,
+                                         const int64_t stride_row,
+                                         const int64_t stride_channel,
+                                         const int64_t stride_sample,
+                                         const float eps,
+                                         const float * mul = nullptr,
+                                         const int64_t mul_stride_row = 0,
+                                         const int64_t mul_stride_channel = 0,
+                                         const int64_t mul_stride_sample = 0,
+                                         const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
+                                         const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
+                                         const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
+                                         const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
+                                         const float * add = nullptr,
+                                         const int64_t add_stride_row = 0,
+                                         const int64_t add_stride_channel = 0,
+                                         const int64_t add_stride_sample = 0,
+                                         const uint3 add_ncols_packed = make_uint3(0, 0, 0),
+                                         const uint3 add_nrows_packed = make_uint3(0, 0, 0),
+                                         const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
+                                         const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    if constexpr (do_multiply) {
+        const uint32_t mul_row     = fastmodulo(row,     mul_nrows_packed);
+        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
+        const uint32_t mul_sample  = fastmodulo(sample,  mul_nsamples_packed);
+        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
+    }
+    if constexpr (do_add) {
+        const int add_row     = fastmodulo(row,     add_nrows_packed);
+        const int add_channel = fastmodulo(channel, add_nchannels_packed);
+        const int add_sample  = fastmodulo(sample,  add_nsamples_packed);
+        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+    }
+
+    const int ncols4 = ncols / 4;
+    const float4 * x4 = reinterpret_cast<const float4 *>(x);
+    float tmp = 0.0f;
+
+    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
+        const float4 v = x4[col4];
+        tmp += v.x*v.x + v.y*v.y + v.z*v.z + v.w*v.w;
+    }
+
+    // sum up partial sums
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
+
+    const float mean  = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    float4 * dst4 = reinterpret_cast<float4 *>(dst);
+    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
+        float4 v = x4[col4];
+        if constexpr (do_multiply && do_add) {
+            const int b = col4*4;
+            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)] + add[fastmodulo(b,   add_ncols_packed)];
+            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)] + add[fastmodulo(b+1, add_ncols_packed)];
+            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)] + add[fastmodulo(b+2, add_ncols_packed)];
+            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)] + add[fastmodulo(b+3, add_ncols_packed)];
+        } else if constexpr (do_multiply) {
+            const int b = col4*4;
+            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)];
+            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)];
+            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)];
+            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)];
+        } else {
+            v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
+        }
+        dst4[col4] = v;
+    }
+}
+
 template <int block_size>
 static __global__ void rms_norm_back_f32(
         const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
@@ -298,12 +385,24 @@ static void rms_norm_f32_cuda(
     const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
     const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps,
     cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
+    const bool use_vec4 = (ncols % 4 == 0)
+        && (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
+        && (reinterpret_cast<uintptr_t>(x) % sizeof(float4) == 0)
+        && (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, false><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        if (use_vec4) {
+            rms_norm_f32_vec4<256, false><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        } else {
+            rms_norm_f32<256, false><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        }
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, false><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        if (use_vec4) {
+            rms_norm_f32_vec4<1024, false><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        } else {
+            rms_norm_f32<1024, false><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        }
     }
 }
@@ -339,6 +438,10 @@ static void rms_norm_mul_f32_cuda(const float * x,
         rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
         return;
     }
+    const bool use_vec4 = (ncols % 4 == 0)
+        && (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
+        && (reinterpret_cast<uintptr_t>(x) % sizeof(float4) == 0)
+        && (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);
     if (add == nullptr) {
         const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
         const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
@@ -346,14 +449,26 @@ static void rms_norm_mul_f32_cuda(const float * x,
         const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
-                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
-                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            if (use_vec4) {
+                rms_norm_f32_vec4<256, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            } else {
+                rms_norm_f32<256, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            }
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
-                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
-                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            if (use_vec4) {
+                rms_norm_f32_vec4<1024, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            } else {
+                rms_norm_f32<1024, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+            }
         }
     } else {
         const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
@@ -367,18 +482,34 @@ static void rms_norm_mul_f32_cuda(const float * x,
         const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
-                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
-                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
-                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
-                add_nchannels_packed, add_nsamples_packed);
+            if (use_vec4) {
+                rms_norm_f32_vec4<256, true, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                    add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                    add_nchannels_packed, add_nsamples_packed);
+            } else {
+                rms_norm_f32<256, true, true><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                    add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                    add_nchannels_packed, add_nsamples_packed);
+            }
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
-                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
-                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
-                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
-                add_nchannels_packed, add_nsamples_packed);
+            if (use_vec4) {
+                rms_norm_f32_vec4<1024, true, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                    add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                    add_nchannels_packed, add_nsamples_packed);
+            } else {
+                rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+                    x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+                    mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+                    add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+                    add_nchannels_packed, add_nsamples_packed);
+            }
         }
     }
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index abf914faa1..b92da383d2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8980,6 +8980,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
 
+    // rms_norm: float4 vectorized path (ncols divisible by 4) and scalar fallback (ncols not divisible by 4)
+    for (uint32_t n : {4, 128, 256, 512, 768, 1024, 2048, 3072, 4096, 5120, 8192, 3, 5, 13, 127, 4097}) {
+        for (int64_t nrows : {1, 32, 512}) {
+            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {n, nrows, 1, 1}, false, 1e-6f));
+        }
+    }
+
     return test_cases;
 }
 