CUDA: add float4 vectorized load/store for rms_norm_f32

Add a separate rms_norm_f32_vec4 kernel using float4 (128-bit) vectorized
memory loads/stores. Host-side dispatch routes to the vec4 kernel when
ncols and all strides are divisible by 4 and both the input and output
pointers are 16-byte aligned; otherwise it falls back to the original
rms_norm_f32 kernel, which is left completely untouched.

A separate kernel is used instead of a runtime branch inside the existing
kernel to avoid register pressure and instruction-cache pollution that
would degrade the scalar path (~22% measured regression with a runtime
branch).

Performance (A100, nrows=512, test-backend-ops perf, 5-run avg):
  [512,512]:  427 -> 624 GB/s (+46%)
  [768,512]:  626 -> 850 GB/s (+36%)
  [1024,512]: 495 -> 645 GB/s (+30%)
  [2048,512]: 911 -> 1171 GB/s (+28%)
  [3072,512]: 1220 -> 1490 GB/s (+22%)
  [5120,512]: 1668 -> 1815 GB/s (+9%)
  Scalar fallback (4097,512): 1476 -> 1471 GB/s (no regression)

Correctness: RMS_NORM 17/17, RMS_NORM_MUL_ADD 30/30,
ADD_RMS_NORM 25/25, RMS_NORM_MUL_ROPE 72/72 passed.
This commit is contained in:
Te-Hsiu Huang 2026-03-13 09:42:15 -07:00 committed by Michael Huang
parent 57819b8d4b
commit 0586379302
2 changed files with 156 additions and 18 deletions

View File

@ -150,6 +150,93 @@ static __global__ void rms_norm_f32(const float * x,
}
}
// float4 (128-bit) vectorized variant of rms_norm_f32: RMS-normalizes each row
// of x into dst, optionally fusing an element-wise multiply (and add) with
// broadcast via the fastdiv/fastmodulo packed divisors.
//
// Preconditions (guaranteed by the host-side dispatch, not re-checked here):
//   - ncols is divisible by 4;
//   - x and dst are 16-byte aligned and stride_row/stride_channel/stride_sample
//     are multiples of 4 floats, so every row pointer stays float4-aligned.
// Grid layout: one block per (row, channel, sample) = (blockIdx.x, .y, .z).
// Shared memory: caller passes 32*sizeof(float) when block_size > WARP_SIZE
// (scratch for the block-wide reduction).
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32_vec4(const float * x,
        float * dst,
        const int ncols,
        const int64_t stride_row,
        const int64_t stride_channel,
        const int64_t stride_sample,
        const float eps,
        const float * mul = nullptr,
        const int64_t mul_stride_row = 0,
        const int64_t mul_stride_channel = 0,
        const int64_t mul_stride_sample = 0,
        const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
        const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
        const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
        const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
        const float * add = nullptr,
        const int64_t add_stride_row = 0,
        const int64_t add_stride_channel = 0,
        const int64_t add_stride_sample = 0,
        const uint3 add_ncols_packed = make_uint3(0, 0, 0),
        const uint3 add_nrows_packed = make_uint3(0, 0, 0),
        const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
        const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;
    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;
    const int tid       = threadIdx.x;

    // Input rows are strided (64-bit via the int64_t strides); output rows are
    // densely packed, so build the dst offset in 64-bit explicitly to avoid
    // int overflow for tensors with more than 2^31 elements.
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += (((int64_t) sample*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        const uint32_t mul_row     = fastmodulo(row,     mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample,  mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }

    if constexpr (do_add) {
        // uint32_t (not int) to match the do_multiply branch above.
        const uint32_t add_row     = fastmodulo(row,     add_nrows_packed);
        const uint32_t add_channel = fastmodulo(channel, add_nchannels_packed);
        const uint32_t add_sample  = fastmodulo(sample,  add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    const int      ncols4 = ncols / 4;
    const float4 * x4     = reinterpret_cast<const float4 *>(x);

    // Pass 1: per-thread partial sum of squares, 4 elements per 128-bit load.
    float tmp = 0.0f;
    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
        const float4 v = x4[col4];
        tmp += v.x*v.x + v.y*v.y + v.z*v.z + v.w*v.w;
    }

    // sum up partial sums
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    // Pass 2: re-read x (a full row does not fit in registers), scale, apply
    // the optional fused multiply/add per scalar column, store vectorized.
    float4 * dst4 = reinterpret_cast<float4 *>(dst);
    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
        float4 v = x4[col4];
        if constexpr (do_multiply && do_add) {
            const int b = col4*4; // scalar column index of v.x
            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)] + add[fastmodulo(b,   add_ncols_packed)];
            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)] + add[fastmodulo(b+1, add_ncols_packed)];
            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)] + add[fastmodulo(b+2, add_ncols_packed)];
            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)] + add[fastmodulo(b+3, add_ncols_packed)];
        } else if constexpr (do_multiply) {
            const int b = col4*4; // scalar column index of v.x
            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)];
            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)];
            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)];
            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)];
        } else {
            v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
        }
        dst4[col4] = v;
    }
}
template <int block_size>
static __global__ void rms_norm_back_f32(
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
@ -298,12 +385,24 @@ static void rms_norm_f32_cuda(
    const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
    const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    // One block per (row, channel, sample).
    const dim3 blocks_num(nrows, nchannels, nsamples);

    // The float4 kernel requires ncols and all strides to be multiples of 4
    // and both base pointers to be 16-byte aligned, so that every row pointer
    // stays float4-aligned; anything else takes the scalar fallback.
    const bool use_vec4 = (ncols % 4 == 0)
        && (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
        && (reinterpret_cast<uintptr_t>(x)   % sizeof(float4) == 0)
        && (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);

    if (ncols < 1024) {
        const dim3 block_dims(256, 1, 1);
        // Reduction scratch: one float per warp when more than one warp runs.
        const size_t nbytes_shared = block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0;
        if (use_vec4) {
            rms_norm_f32_vec4<256, false><<<blocks_num, block_dims, nbytes_shared, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        } else {
            rms_norm_f32<256, false><<<blocks_num, block_dims, nbytes_shared, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        }
    } else {
        const dim3 block_dims(1024, 1, 1);
        const size_t nbytes_shared = block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0;
        if (use_vec4) {
            rms_norm_f32_vec4<1024, false><<<blocks_num, block_dims, nbytes_shared, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        } else {
            rms_norm_f32<1024, false><<<blocks_num, block_dims, nbytes_shared, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        }
    }
}
@ -339,6 +438,10 @@ static void rms_norm_mul_f32_cuda(const float * x,
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return;
}
const bool use_vec4 = (ncols % 4 == 0)
&& (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
&& (reinterpret_cast<uintptr_t>(x) % sizeof(float4) == 0)
&& (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);
if (add == nullptr) {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
@ -346,14 +449,26 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
if (use_vec4) {
rms_norm_f32_vec4<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
if (use_vec4) {
rms_norm_f32_vec4<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
}
} else {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
@ -367,18 +482,34 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
if (use_vec4) {
rms_norm_f32_vec4<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
}
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
if (use_vec4) {
rms_norm_f32_vec4<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
}
}
}
}

View File

@ -8980,6 +8980,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
// rms_norm: float4 vectorized path (ncols divisible by 4) and scalar fallback (ncols not divisible by 4)
for (uint32_t n : {4, 128, 256, 512, 768, 1024, 2048, 3072, 4096, 5120, 8192, 3, 5, 13, 127, 4097}) {
for (int64_t nrows : {1, 32, 512}) {
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {n, nrows, 1, 1}, false, 1e-6f));
}
}
return test_cases;
}