Merge 0586379302 into 9e2e2198b0
This commit is contained in:
commit
cb58f77fdf
|
|
@ -150,6 +150,93 @@ static __global__ void rms_norm_f32(const float * x,
|
|||
}
|
||||
}
|
||||
|
||||
// float4-vectorized variant of rms_norm_f32: normalizes each row of x by its
// root-mean-square and writes the result contiguously to dst, optionally
// fusing an element-wise multiply and, on top of that, an element-wise add.
//
// Grid layout: one thread block per row — blockIdx.x = row,
// blockIdx.y = channel, blockIdx.z = sample.
// Dynamic shared memory: 32 * sizeof(float) when block_size > WARP_SIZE
// (scratch for the block-wide sum reduction), 0 otherwise.
//
// Preconditions (checked by the host-side dispatch, not here):
//   - ncols % 4 == 0;
//   - x and dst are 16-byte aligned, and stride_row/stride_channel/
//     stride_sample are multiples of 4 floats, so rows can be accessed
//     as float4.
// The *_packed arguments are the mul/add operand extents encoded by
// init_fastdiv_values() for fastmodulo-based broadcast indexing.
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32_vec4(const float * x,
                                         float * dst,
                                         const int ncols,
                                         const int64_t stride_row,
                                         const int64_t stride_channel,
                                         const int64_t stride_sample,
                                         const float eps,
                                         const float * mul = nullptr,
                                         const int64_t mul_stride_row = 0,
                                         const int64_t mul_stride_channel = 0,
                                         const int64_t mul_stride_sample = 0,
                                         const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
                                         const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
                                         const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
                                         const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
                                         const float * add = nullptr,
                                         const int64_t add_stride_row = 0,
                                         const int64_t add_stride_channel = 0,
                                         const int64_t add_stride_sample = 0,
                                         const uint3 add_ncols_packed = make_uint3(0, 0, 0),
                                         const uint3 add_nrows_packed = make_uint3(0, 0, 0),
                                         const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
                                         const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;
    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;
    const int tid       = threadIdx.x;

    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    // Source may be strided; destination is written densely.
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        // Broadcast the mul operand over rows/channels/samples via fastmodulo.
        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }
    if constexpr (do_add) {
        // uint32_t to match the fastmodulo result type (was int — avoids an
        // implicit narrowing conversion and keeps both fused paths consistent).
        const uint32_t add_row     = fastmodulo(row, add_nrows_packed);
        const uint32_t add_channel = fastmodulo(channel, add_nchannels_packed);
        const uint32_t add_sample  = fastmodulo(sample, add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    // Pass 1: accumulate the sum of squares, 4 elements per iteration.
    const int ncols4 = ncols / 4;
    const float4 * x4 = reinterpret_cast<const float4 *>(x);
    float tmp = 0.0f;

    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
        const float4 v = x4[col4];
        tmp += v.x*v.x + v.y*v.y + v.z*v.z + v.w*v.w;
    }

    // sum up partial sums
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    // Pass 2: scale (and optionally multiply/add) and store as float4.
    float4 * dst4 = reinterpret_cast<float4 *>(dst);
    for (int col4 = tid; col4 < ncols4; col4 += block_size) {
        float4 v = x4[col4];
        if constexpr (do_multiply && do_add) {
            // Per-element fastmodulo handles mul/add operands narrower than ncols.
            const int b = col4*4;
            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)] + add[fastmodulo(b,   add_ncols_packed)];
            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)] + add[fastmodulo(b+1, add_ncols_packed)];
            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)] + add[fastmodulo(b+2, add_ncols_packed)];
            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)] + add[fastmodulo(b+3, add_ncols_packed)];
        } else if constexpr (do_multiply) {
            const int b = col4*4;
            v.x = scale * v.x * mul[fastmodulo(b,   mul_ncols_packed)];
            v.y = scale * v.y * mul[fastmodulo(b+1, mul_ncols_packed)];
            v.z = scale * v.z * mul[fastmodulo(b+2, mul_ncols_packed)];
            v.w = scale * v.w * mul[fastmodulo(b+3, mul_ncols_packed)];
        } else {
            v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
        }
        dst4[col4] = v;
    }
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_back_f32(
|
||||
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
|
||||
|
|
// Host-side dispatch for the non-fused RMS norm.
// Launches one thread block per (row, channel, sample). Uses the
// float4-vectorized kernel when ncols and all strides are multiples of 4 and
// both pointers are 16-byte aligned; falls back to the scalar kernel
// otherwise. Block size: 256 threads for ncols < 1024, else 1024.
//
// Fix: the pre-vectorization unconditional launches had been left in place
// above the use_vec4 dispatch, so each row was normalized twice (and the
// vectorized path never saved any work). Only one kernel is launched now.
static void rms_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    // The vectorized kernel reads/writes float4, so every row of both x and
    // dst must stay 16-byte aligned across all strides.
    const bool use_vec4 = (ncols % 4 == 0)
        && (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
        && (reinterpret_cast<uintptr_t>(x) % sizeof(float4) == 0)
        && (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);
    if (ncols < 1024) {
        const dim3 block_dims(256, 1, 1);
        // Dynamic shared memory is only needed for the cross-warp reduction.
        const size_t smem = block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0;
        if (use_vec4) {
            rms_norm_f32_vec4<256, false><<<blocks_num, block_dims, smem, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        } else {
            rms_norm_f32<256, false><<<blocks_num, block_dims, smem, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        }
    } else {
        const dim3 block_dims(1024, 1, 1);
        const size_t smem = block_dims.x > WARP_SIZE ? 32 * sizeof(float) : 0;
        if (use_vec4) {
            rms_norm_f32_vec4<1024, false><<<blocks_num, block_dims, smem, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        } else {
            rms_norm_f32<1024, false><<<blocks_num, block_dims, smem, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        }
    }
}
|
||||
|
||||
|
|
@ -339,6 +438,10 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|||
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
|
||||
return;
|
||||
}
|
||||
const bool use_vec4 = (ncols % 4 == 0)
|
||||
&& (stride_row % 4 == 0) && (stride_channel % 4 == 0) && (stride_sample % 4 == 0)
|
||||
&& (reinterpret_cast<uintptr_t>(x) % sizeof(float4) == 0)
|
||||
&& (reinterpret_cast<uintptr_t>(dst) % sizeof(float4) == 0);
|
||||
if (add == nullptr) {
|
||||
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
||||
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
|
||||
|
|
@ -346,14 +449,26 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|||
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
if (use_vec4) {
|
||||
rms_norm_f32_vec4<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
} else {
|
||||
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
}
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
if (use_vec4) {
|
||||
rms_norm_f32_vec4<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
} else {
|
||||
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
|
||||
|
|
@ -367,18 +482,34 @@ static void rms_norm_mul_f32_cuda(const float * x,
|
|||
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
if (use_vec4) {
|
||||
rms_norm_f32_vec4<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
} else {
|
||||
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
}
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
if (use_vec4) {
|
||||
rms_norm_f32_vec4<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
} else {
|
||||
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
|
||||
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
|
||||
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
|
||||
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
|
||||
add_nchannels_packed, add_nsamples_packed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8982,6 +8982,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
|
||||
|
||||
// rms_norm: float4 vectorized path (ncols divisible by 4) and scalar fallback (ncols not divisible by 4)
|
||||
for (uint32_t n : {4, 128, 256, 512, 768, 1024, 2048, 3072, 4096, 5120, 8192, 3, 5, 13, 127, 4097}) {
|
||||
for (int64_t nrows : {1, 32, 512}) {
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {n, nrows, 1, 1}, false, 1e-6f));
|
||||
}
|
||||
}
|
||||
|
||||
return test_cases;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue