CUDA: Factor out and re-use `block_reduce` function (#18785)

* CUDA: Refactor and expose two_stage_warp_reduce_* function

* Use `two_stage_warp_reduce` also in softmax kernel, move smem out of it

Moving smem out of `__device__` function to `__global__` function
allows for explicit smem reuse, as either compiler or cuda rt seem to not
free it afterwards (`cudaFuncSetAttribute` fails when not accounting for
it once for each call to two_stage_warp_reduce)

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* Use two_stage_warp_reduce in group_norm_f32

* Use two_stage_warp_reduce in rms_norm_f32

* Fix smem calculation which expects bytes

* Make `two_stage_warp_reduce` accept all values warp_reduce accepts

Also integrate it into norm_f32 function

* Use two_stage_warp_reduce in l2_norm_f32

* Use type traits for block reduction for better legibility

Also adresss other requests by @am17an such as variable renaming

* Make norm tests cover all cuda paths

* Mark columns % WARP_SIZE !=0 as supported for RMS_NORM_BACK

Unit-tests passed locally, let's see if they pass in the CI as well

* Use `enum class` for `block_reduce_method`

This is more type-safe than plain enum

* Rename variables as suggested in code review by @am17an

* Rename two_stage_warp_reduce -> block_reduce

* Fix trailing whitespace in common.cuh

* Make condition of static_assert type-dependent

This delays evaluation until the template is actually instantiated.
Otherwise, some compilers may evaluate the assert when parsing the
template, resulting in build errors as observed here:

https://github.com/ggml-org/llama.cpp/actions/runs/20960323123/job/60235530068?pr=18785

* Inline definitions

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Oliver Simons 2026-01-15 03:44:54 +01:00 committed by GitHub
parent d98b548120
commit 36f0132464
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 125 additions and 191 deletions

View File

@ -530,6 +530,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#endif // FP16_AVAILABLE
}
enum class block_reduce_method {
MAX,
SUM,
};
template<block_reduce_method method_t, typename T>
struct block_reduce_policy;
template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
template<typename...>
inline constexpr bool ggml_cuda_dependent_false_v = false;
template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
static __device__ T reduce(T val) {
if constexpr(is_any<T, float, float2, half2, int>) {
return warp_reduce_sum(val);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
static __device__ T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return 0.0f;
} else if constexpr (std::is_same_v<T, float2>) {
return make_float2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, half2>) {
return make_half2(0.0f, 0.0f);
} else if constexpr (std::is_same_v<T, int>) {
return 0;
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
}
}
};
template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
static __device__ T reduce(T val) {
if constexpr (is_any<T, float, half2>) {
return warp_reduce_max(val);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
static __device__ T sentinel() {
if constexpr (std::is_same_v<T, float>) {
return -INFINITY;
} else if constexpr (std::is_same_v<T, half2>) {
return make_half2(-INFINITY, -INFINITY);
} else {
static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
}
}
};
template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
static __device__ T block_reduce(T val, T * shared_vals) {
val = block_reduce_policy<reduce_method_t, T>::reduce(val);
const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
if (block_size > WARP_SIZE) {
assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
shared_vals[warp_id] = val;
}
__syncthreads();
val = block_reduce_policy<reduce_method_t, T>::sentinel();
if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
val = shared_vals[lane_id];
}
return block_reduce_policy<reduce_method_t, T>::reduce(val);
}
return val;
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

View File

@ -4551,7 +4551,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
return ggml_is_contiguous(op->src[0]);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:

View File

@ -25,19 +25,8 @@ static __global__ void norm_f32(
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float2 s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
extern __shared__ float2 s_sum2[];
mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp += x[j];
}
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / group_size;
tmp = 0.0f;
@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float variance = tmp / group_size;
const float scale = rsqrtf(variance + eps);
@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = 0.0f;
if (lane_id < (block_size / WARP_SIZE)) {
tmp = s_sum[lane_id];
}
tmp = warp_reduce_sum(tmp);
}
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
extern __shared__ float s_sum[];
tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@ -337,7 +279,7 @@ static void norm_f32_cuda(
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float * x,
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}

View File

@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
}
// sum up partial sums
sum = warp_reduce_sum(sum);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = sum;
}
__syncthreads();
sum = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
sum = s_sum[lane_id];
}
sum = warp_reduce_sum(sum);
}
__shared__ float shared_vals[32];
sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
if (col != 0) {
return;

View File

@ -75,9 +75,6 @@ static __global__ void soft_max_f32(
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
@ -102,21 +99,7 @@ static __global__ void soft_max_f32(
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
float tmp = 0.0f; // partial sum
@ -134,22 +117,7 @@ static __global__ void soft_max_f32(
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
if (sinks) {
tmp += expf(sinks[i02] - max_val);
@ -169,50 +137,6 @@ static __global__ void soft_max_f32(
}
}
// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
static __device__ float two_stage_warp_reduce_max(float val) {
val = warp_reduce_max(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = -INFINITY;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_max(val);
} else {
return val;
}
}
static __device__ float two_stage_warp_reduce_sum(float val) {
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_sum(val);
} else {
return val;
}
}
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
__shared__ float shared_vals[32];
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
}
// Compute CTA-level max
local_max = two_stage_warp_reduce_max(local_max);
local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Store CTA-level max to GMEM
if (tid == 0) {
@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
} else {
local_max = -INFINITY;
}
local_max = two_stage_warp_reduce_max(local_max);
local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
}
// Reduce divisor within CTA
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Store CTA-level sum to GMEM
if (tid == 0) {
@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
} else {
tmp_expf = 0.0f;
}
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {

View File

@ -7482,25 +7482,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f }) {
for (uint32_t n : { 64, 1025 }) {
for (bool v : { false, true }) {
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
// in-place tests
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
for (float eps : { 0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f }) {
for (uint32_t n : { 64, 1025 }) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
}
}
for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
for (bool multi_add : {false, true}) {
@ -7524,9 +7528,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner: {1024, 1536, 2048}) {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));