diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index f759e2d588..f9afb30424 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -571,6 +571,7 @@ extern "C" {
         GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
+        GGML_OP_MOE_SUM,
 
         GGML_OP_COUNT,
     };
@@ -1666,6 +1667,15 @@ extern "C" {
            struct ggml_tensor  * b,  // source
            struct ggml_tensor  * c); // row indices
 
+    // a      [hidden_dim, n_expert_used, n_tokens]
+    // result [hidden_dim, n_tokens]
+    //
+    // Sum the outputs from multiple experts for MoE models
+    GGML_API struct ggml_tensor * ggml_moe_sum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_expert_used);
+
     GGML_API struct ggml_tensor * ggml_diag(
         struct ggml_context     * ctx,
         struct ggml_tensor      * a);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b003fe13fd..c26622027f 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1997,6 +1997,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_glu(params, tensor);
             } break;
+        case GGML_OP_MOE_SUM:
+            {
+                ggml_compute_forward_moe_sum(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2259,6 +2263,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     GGML_ABORT("fatal error");
             }
             break;
+        case GGML_OP_MOE_SUM:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
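Note on the CPU scheduling above: GGML_OP_MOE_SUM runs with n_tasks = n_threads, and in the ops.cpp changes below each thread then sums a contiguous token range [ir0, ir1) returned by get_thread_range(). For readers unfamiliar with that helper, the split works roughly like the following sketch (illustrative only; the function name here is hypothetical, and the patch itself relies on the existing get_thread_range() in ops.cpp):

    // Sketch: ceil-divide n_tokens across nth threads; thread ith gets [ir0, ir1).
    static void moe_sum_thread_range(int ith, int nth, int64_t n_tokens, int64_t * ir0, int64_t * ir1) {
        const int64_t per_thread = (n_tokens + nth - 1) / nth; // ceil division
        *ir0 = per_thread * ith;
        *ir1 = *ir0 + per_thread < n_tokens ? *ir0 + per_thread : n_tokens;
    }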
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index ce15b18ce0..b771be42c1 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9678,6 +9678,110 @@ void ggml_compute_forward_glu(
     }
 }
 
+// ggml_compute_forward_moe_sum
+
+template <typename src_t, typename dst_t>
+static void ggml_compute_forward_moe_sum_impl(const ggml_tensor * src0, ggml_tensor * dst,
+                                              int64_t ir0, int64_t ir1) {
+    constexpr auto src_to_f32 = type_conversion_table<src_t>::to_f32;
+    constexpr auto f32_to_dst = type_conversion_table<dst_t>::from_f32;
+
+    const int64_t hidden_dim    = src0->ne[0];
+    const int64_t n_expert_used = src0->ne[1];
+
+    const src_t * src      = (const src_t *) src0->data;
+    dst_t       * dst_data = (dst_t *) dst->data;
+
+    const size_t nb_expert    = src0->nb[1] / sizeof(src_t);
+    const size_t nb_token_src = src0->nb[2] / sizeof(src_t);
+    const size_t nb_token_dst = dst->nb[1]  / sizeof(dst_t);
+
+    // Process tokens [ir0, ir1) assigned to this thread
+    // Initialize the dst region to zero first
+    for (int64_t t = ir0; t < ir1; t++) {
+        dst_t * dst_token = dst_data + t * nb_token_dst;
+        for (int64_t h = 0; h < hidden_dim; h++) {
+            dst_token[h] = f32_to_dst(0.0f);
+        }
+    }
+
+    // Accumulate each expert's contribution
+    // Loop order: expert -> token -> hidden_dim for better cache locality
+    for (int64_t e = 0; e < n_expert_used; e++) {
+        for (int64_t t = ir0; t < ir1; t++) {
+            const src_t * src_token = src      + t * nb_token_src + e * nb_expert;
+            dst_t       * dst_token = dst_data + t * nb_token_dst;
+
+            for (int64_t h = 0; h < hidden_dim; h++) {
+                dst_token[h] = f32_to_dst(src_to_f32(dst_token[h]) + src_to_f32(src_token[h]));
+            }
+        }
+    }
+}
+
+// Specialized F32 implementation - no conversion needed, better cache locality
+static void ggml_compute_forward_moe_sum_f32(const ggml_tensor * src0, ggml_tensor * dst,
+                                             int64_t ir0, int64_t ir1) {
+    const int64_t hidden_dim    = src0->ne[0];
+    const int64_t n_expert_used = src0->ne[1];
+
+    const float * src = (const float *) src0->data;
+    float * __restrict dst_data = (float *) dst->data;
+
+    const size_t nb_expert    = src0->nb[1] / sizeof(float);
+    const size_t nb_token_src = src0->nb[2] / sizeof(float);
+    const size_t nb_token_dst = dst->nb[1]  / sizeof(float);
+
+    // Initialize dst region to zero
+    for (int64_t t = ir0; t < ir1; t++) {
+        float * dst_token = dst_data + t * nb_token_dst;
+        for (int64_t h = 0; h < hidden_dim; h++) {
+            dst_token[h] = 0.0f;
+        }
+    }
+
+    // Accumulate each expert's contribution
+    // Loop order: expert -> token -> hidden_dim for better cache locality
+    for (int64_t e = 0; e < n_expert_used; e++) {
+        for (int64_t t = ir0; t < ir1; t++) {
+            const float * src_token = src + t * nb_token_src + e * nb_expert;
+            float * __restrict dst_token = dst_data + t * nb_token_dst;
+
+            // Use pointer arithmetic for better vectorization
+            const float * src_end = src_token + hidden_dim;
+            float       * dst_ptr = dst_token;
+            const float * src_ptr = src_token;
+
+            while (src_ptr < src_end) {
+                *dst_ptr++ += *src_ptr++;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_moe_sum(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(src0->type == dst->type);
+
+    const auto [ir0, ir1] = get_thread_range(params, dst);
+
+    // Dispatch based on data type
+    if (src0->type == GGML_TYPE_F32) {
+        ggml_compute_forward_moe_sum_f32(src0, dst, ir0, ir1);
+    } else if (src0->type == GGML_TYPE_F16) {
+        ggml_compute_forward_moe_sum_impl<ggml_fp16_t, ggml_fp16_t>(src0, dst, ir0, ir1);
+    } else if (src0->type == GGML_TYPE_BF16) {
+        ggml_compute_forward_moe_sum_impl<ggml_bf16_t, ggml_bf16_t>(src0, dst, ir0, ir1);
+    } else {
+        GGML_ABORT("fatal error: unsupported type for moe_sum");
+    }
+}
+
 // ggml_compute_forward_get_rel_pos
 
 static void ggml_compute_forward_get_rel_pos_f16(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee7976..93d4455303 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -96,6 +96,7 @@ void ggml_compute_forward_win_part(const struct ggml_compute_params * params, st
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_moe_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index eeb8625dbe..999e3f8114 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2574,6 +2574,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 return false;
             }
             break;
+        case GGML_OP_MOE_SUM:
+            ggml_cuda_op_moe_sum(ctx, dst);
+            break;
         case GGML_OP_NORM:
             ggml_cuda_op_norm(ctx, dst);
             break;
@@ -4561,6 +4564,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             break;
+        case GGML_OP_MOE_SUM:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
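For the new CUDA file that follows: with a contiguous src0 of shape [hidden_dim, n_expert_used, n_tokens], stride_token = nb[2]/nb[0] = hidden_dim * n_expert_used, stride_topk = nb[1]/nb[0] = hidden_dim, and out_stride_token = hidden_dim. Every kernel below therefore computes the same reduction as this scalar reference (a reading aid, not part of the patch):

    // y[t][d] = sum over k of x[t][k][d], written with the flat indices the kernels use
    for (int t = 0; t < token_num; ++t) {
        for (int d = 0; d < hidden_dim; ++d) {
            float acc = 0.0f;
            for (int k = 0; k < topk_num; ++k) {
                acc += x[t * stride_token + k * stride_topk + d];
            }
            y[t * out_stride_token + d] = acc;
        }
    }

The kernel variants only differ in how this loop nest is mapped onto blocks, warps, and 16-byte vector loads.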
diff --git a/ggml/src/ggml-cuda/moesum.cu b/ggml/src/ggml-cuda/moesum.cu
new file mode 100644
index 0000000000..d1c3b07345
--- /dev/null
+++ b/ggml/src/ggml-cuda/moesum.cu
@@ -0,0 +1,342 @@
+#include "moesum.cuh"
+
+template <typename T>
+__device__ __forceinline__ T ldg_cg(const T* p) {
+    return __ldg(p);
+}
+
+union Pack16B {
+    uint4  v;
+    __half u16[8];
+};
+
+template <int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_warp_token_vec_kernel(
+    const half* __restrict__ x,
+    half* __restrict__ y,
+    const int32_t token_num,
+    const int32_t hidden_dim,
+    const int32_t topk_num,
+    const int32_t stride_token,     // in elements
+    const int32_t stride_topk,      // in elements
+    const int32_t out_stride_token  // in elements
+) {
+    constexpr int VEC   = 16;
+    constexpr int PACKS = VEC / 8;
+
+    const int warp_id = threadIdx.x / 32;
+    const int lane    = threadIdx.x % 32;
+    const int32_t t = blockIdx.y * WARPS_PER_BLOCK + warp_id;
+    if (t >= token_num) return;
+
+    const int32_t n_chunks = hidden_dim / VEC;
+
+    for (int32_t chunk = blockIdx.x * 32 + lane; chunk < n_chunks; chunk += (int32_t)gridDim.x * 32) {
+        const int32_t d    = chunk * VEC;
+        const int32_t base = t * stride_token + d;
+
+        float acc[VEC];
+#pragma unroll
+        for (int i = 0; i < VEC; ++i)
+            acc[i] = 0.f;
+
+#pragma unroll
+        for (int k = 0; k < topk_num; ++k) {
+#pragma unroll
+            for (int p = 0; p < PACKS; ++p) {
+                const int32_t offset = base + (int32_t)k * stride_topk + p * 8;
+                Pack16B pack = {ldg_cg(reinterpret_cast<const uint4*>(x + offset))};
+
+#pragma unroll
+                for (int i = 0; i < 8; ++i) {
+                    acc[p * 8 + i] += static_cast<float>(pack.u16[i]);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int p = 0; p < PACKS; ++p) {
+            Pack16B outp;
+#pragma unroll
+            for (int i = 0; i < 8; ++i) {
+                outp.u16[i] = static_cast<__half>(acc[p * 8 + i]);
+            }
+            const int32_t dst = t * out_stride_token + d + p * 8;
+            *reinterpret_cast<uint4*>(y + dst) = outp.v;
+        }
+    }
+}
+
+template <typename scalar_t, int TOPK, int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_warp_token_kernel(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int32_t token_num,
+    const int32_t hidden_dim,
+    const int32_t stride_token,
+    const int32_t stride_topk,
+    const int32_t out_stride_token) {
+    const int warp_id = threadIdx.x / 32;
+    const int lane    = threadIdx.x % 32;
+    const int32_t t = blockIdx.y * WARPS_PER_BLOCK + warp_id;
+    if (t >= token_num) return;
+
+    for (int32_t d = blockIdx.x * 32 + lane; d < hidden_dim; d += gridDim.x * 32) {
+        float acc = 0.f;
+        const int32_t base = t * stride_token + d;
+
+#pragma unroll
+        for (int k = 0; k < TOPK; ++k) {
+            acc += static_cast<float>(x[base + k * stride_topk]);
+        }
+
+        y[t * out_stride_token + d] = static_cast<scalar_t>(acc);
+    }
+}
+
+template <typename scalar_t, int WARPS_PER_BLOCK>
+__global__ void moe_sum_reduce_warp_token_kernel_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int32_t token_num,
+    const int32_t hidden_dim,
+    const int32_t stride_token,
+    const int32_t stride_topk,
+    const int32_t out_stride_token,
+    const int topk_num) {
+    const int warp_id = threadIdx.x / 32;
+    const int lane    = threadIdx.x % 32;
+    const int32_t t = blockIdx.y * WARPS_PER_BLOCK + warp_id;
+    if (t >= token_num) return;
+
+    for (int32_t d = blockIdx.x * 32 + lane; d < hidden_dim; d += gridDim.x * 32) {
+        float acc = 0.f;
+        const int32_t base = t * stride_token + d;
+#pragma unroll 1
+        for (int k = 0; k < topk_num; ++k) {
+            acc += static_cast<float>(x[base + k * stride_topk]);
+        }
+
+        y[t * out_stride_token + d] = static_cast<scalar_t>(acc);
+    }
+}
+
+template <typename scalar_t, int TOPK>
+__global__ void moe_sum_reduce_kernel(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int32_t token_num,
+    const int32_t hidden_dim,
+    const int32_t stride_token,
+    const int32_t stride_topk,
+    const int32_t out_stride_token) {
+    for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
+        for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
+            const int32_t base = t * stride_token + d;
+            float acc = 0.f;
+
+#pragma unroll
+            for (int k = 0; k < TOPK; ++k) {
+                acc += static_cast<float>(x[base + k * stride_topk]);
+            }
+
+            y[t * out_stride_token + d] = static_cast<scalar_t>(acc);
+        }
+    }
+}
+
+// -------------------- general-topk fallback kernels --------------------
+// small-token
+template <typename scalar_t>
+__global__ void moe_sum_reduce_kernel_general(
+    const scalar_t* __restrict__ x,
+    scalar_t* __restrict__ y,
+    const int32_t token_num,
+    const int32_t hidden_dim,
+    const int32_t stride_token,
+    const int32_t stride_topk,
+    const int32_t out_stride_token,
+    const int topk_num) {
+    for (int t = blockIdx.y; t < token_num; t += gridDim.y) {
+        for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) {
+            const int32_t base = t * stride_token + d;
+            float acc = 0.f;
+
+#pragma unroll 1
+            for (int k = 0; k < topk_num; ++k) {
+                acc += static_cast<float>(x[base + k * stride_topk]);
+            }
+
+            y[t * out_stride_token + d] = static_cast<scalar_t>(acc);
+        }
+    }
+}
+
+#define LAUNCH_SMALL_TOKEN_KERNEL(scalar_t, TOPK) \
+    moe_sum_reduce_kernel<scalar_t, TOPK><<<grid, block, 0, stream>>>( \
+        static_cast<const scalar_t*>(src0->data), \
+        static_cast<scalar_t*>(dst->data), \
+        token_num, \
+        hidden_dim, \
+        stride_token, \
+        stride_topk, \
+        out_stride_token);
+
+#define LAUNCH_GENERIC_KERNEL(scalar_t) \
+    moe_sum_reduce_kernel_general<scalar_t> \
+        <<<grid, block, 0, stream>>>( \
+        static_cast<const scalar_t*>(src0->data), \
+        static_cast<scalar_t*>(dst->data), \
+        token_num, \
+        hidden_dim, \
+        stride_token, \
+        stride_topk, \
+        out_stride_token, \
+        topk_num);
+
+#define LAUNCH_WARP_PER_TOKEN_KERNEL(scalar_t, TOPK) \
+    moe_sum_reduce_warp_token_kernel<scalar_t, TOPK, WARPS_PER_BLOCK> \
+        <<<grid, block, 0, stream>>>( \
+        static_cast<const scalar_t*>(src0->data), \
+        static_cast<scalar_t*>(dst->data), \
+        token_num, \
+        hidden_dim, \
+        stride_token, \
+        stride_topk, \
+        out_stride_token);
+
+#define LAUNCH_WARP_PER_TOKEN_GENERIC_KERNEL(scalar_t) \
+    moe_sum_reduce_warp_token_kernel_general<scalar_t, WARPS_PER_BLOCK> \
+        <<<grid, block, 0, stream>>>( \
+        static_cast<const scalar_t*>(src0->data), \
+        static_cast<scalar_t*>(dst->data), \
+        token_num, \
+        hidden_dim, \
+        stride_token, \
+        stride_topk, \
+        out_stride_token, \
+        topk_num);
+
+void ggml_cuda_op_moe_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    // src0: [hidden_dim, n_expert_used, n_tokens]
+    ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+    GGML_ASSERT(src0->ne[2] == dst->ne[1]);
+
+    const int token_num  = src0->ne[2];
+    const int topk_num   = src0->ne[1];
+    const int hidden_dim = src0->ne[0];
+
+    const int stride_token     = src0->nb[2] / src0->nb[0];
+    const int stride_topk      = src0->nb[1] / src0->nb[0];
+    const int out_stride_token = dst->nb[1]  / dst->nb[0];
+
+    auto stream = ctx.stream();
+
+    const bool fast_fp16_vec_ok = (src0->type == GGML_TYPE_F16) &&
+        (token_num > 256) && (hidden_dim % 16 == 0); // the vec kernel sums VEC = 16 halves per chunk
+    if (fast_fp16_vec_ok) {
+        constexpr int WARPS_PER_BLOCK = 8;
+        constexpr int THREADS = WARPS_PER_BLOCK * 32;
+
+        const int n_chunks = hidden_dim / 16;
+        int grid_x = (n_chunks + 32 - 1) / 32;
+        int grid_y = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
+
+        dim3 block(THREADS);
+        dim3 grid(grid_x, grid_y);
+
+        moe_sum_reduce_warp_token_vec_kernel<WARPS_PER_BLOCK>
+            <<<grid, block, 0, stream>>>(
+                static_cast<const half*>(src0->data),
+                static_cast<half*>(dst->data),
+                token_num,
+                hidden_dim,
+                topk_num,
+                stride_token,
+                stride_topk,
+                out_stride_token);
+        CUDA_CHECK(cudaGetLastError());
+        return;
+    }
+
+    const bool per_token_use_one_warp = (token_num > 128);
+    if (!per_token_use_one_warp) {
+        // small token num
+        const int block_size = 256;
+        int grid_x = (hidden_dim + block_size - 1) / block_size;
+        int grid_y = token_num;
+
+        dim3 block(block_size);
+        dim3 grid(grid_x, grid_y);
+
+        if (src0->type == GGML_TYPE_F32) {
+            if (topk_num == 2) {
+                LAUNCH_SMALL_TOKEN_KERNEL(float, 2);
+            } else if (topk_num == 4) {
+                LAUNCH_SMALL_TOKEN_KERNEL(float, 4);
+            } else if (topk_num == 8) {
+                LAUNCH_SMALL_TOKEN_KERNEL(float, 8);
+            } else if (topk_num == 9) {
+                LAUNCH_SMALL_TOKEN_KERNEL(float, 9);
+            } else {
+                LAUNCH_GENERIC_KERNEL(float);
+            }
+        } else if (src0->type == GGML_TYPE_F16) {
+            if (topk_num == 2) {
+                LAUNCH_SMALL_TOKEN_KERNEL(half, 2);
+            } else if (topk_num == 4) {
+                LAUNCH_SMALL_TOKEN_KERNEL(half, 4);
+            } else if (topk_num == 8) {
+                LAUNCH_SMALL_TOKEN_KERNEL(half, 8);
+            } else if (topk_num == 9) {
+                LAUNCH_SMALL_TOKEN_KERNEL(half, 9);
+            } else {
+                LAUNCH_GENERIC_KERNEL(half);
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+    } else {
+        // warp-per-token
+        constexpr int WARPS_PER_BLOCK = 4;
+        constexpr int THREADS = WARPS_PER_BLOCK * 32;
+
+        int grid_x = (hidden_dim + 32 - 1) / 32;
+        int grid_y = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK;
+
+        dim3 block(THREADS);
+        dim3 grid(grid_x, grid_y);
+
+        if (src0->type == GGML_TYPE_F32) {
+            if (topk_num == 2) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(float, 2);
+            } else if (topk_num == 4) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(float, 4);
+            } else if (topk_num == 8) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(float, 8);
+            } else if (topk_num == 9) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(float, 9);
+            } else {
+                LAUNCH_WARP_PER_TOKEN_GENERIC_KERNEL(float);
+            }
+        } else if (src0->type == GGML_TYPE_F16) {
+            if (topk_num == 2) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(half, 2);
+            } else if (topk_num == 4) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(half, 4);
+            } else if (topk_num == 8) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(half, 8);
+            } else if (topk_num == 9) {
+                LAUNCH_WARP_PER_TOKEN_KERNEL(half, 9);
+            } else {
+                LAUNCH_WARP_PER_TOKEN_GENERIC_KERNEL(half);
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+    }
+}
diff --git a/ggml/src/ggml-cuda/moesum.cuh b/ggml/src/ggml-cuda/moesum.cuh
new file mode 100644
index 0000000000..b31b123f58
--- /dev/null
+++ b/ggml/src/ggml-cuda/moesum.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_moe_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 500cb6b72f..69597092ff 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1045,9 +1045,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_SGD",
 
     "GLU",
+    "MOE_SUM",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1154,9 +1155,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "sgd(x)",
 
     "glu(x)",
+    "moe_sum(x, n)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -3017,6 +3019,22 @@ struct ggml_tensor * ggml_swiglu_oai(
     return result;
 }
 
+// ggml_moe_sum
+
+struct ggml_tensor * ggml_moe_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_expert_used) {
+    GGML_ASSERT(a->ne[1] == n_expert_used);
+    const int64_t ne[2] = { a->ne[0], a->ne[2] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne);
+
+    result->op     = GGML_OP_MOE_SUM;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
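For reference, a minimal usage sketch of the new API (not part of the patch; sizes, names, and the context setup are illustrative):

    // Build dst[h, t] = sum over e of src[h, e, t] with a single graph node.
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // per-expert outputs, e.g. as produced by ggml_mul_mat_id: [hidden_dim, n_expert_used, n_tokens]
    struct ggml_tensor * experts_out = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4096, 8, 32);
    struct ggml_tensor * moe_out     = ggml_moe_sum(ctx, experts_out, /*n_expert_used=*/ 8); // -> [4096, 32]

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, moe_out);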
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index cecdf47038..5a02897763 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2155,6 +2155,38 @@
     }
 };
 
+// GGML_OP_MOE_SUM
+struct test_moe_sum : public test_case {
+    const ggml_type type;
+    const int64_t hidden_dim;
+    const int64_t n_expert_used;
+    const int64_t n_tokens;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, hidden_dim, n_expert_used, n_tokens);
+    }
+
+    // F16 has limited precision when summing multiple expert outputs
+    double max_nmse_err() override {
+        return type == GGML_TYPE_F16 ? 1e-6 : 1e-7;
+    }
+
+    test_moe_sum(ggml_type type = GGML_TYPE_F32,
+                 int64_t hidden_dim = 128,
+                 int64_t n_expert_used = 4,
+                 int64_t n_tokens = 16)
+        : type(type), hidden_dim(hidden_dim), n_expert_used(n_expert_used), n_tokens(n_tokens) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_3d(ctx, type, hidden_dim, n_expert_used, n_tokens);
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+        ggml_tensor * out = ggml_moe_sum(ctx, a, n_expert_used);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -7025,6 +7057,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (int64_t n_expert_used : {2, 4, 8}) {
+            for (int64_t hidden_dim : {64, 128, 256, 4096}) {
+                for (int64_t n_tokens : {16, 32, 128, 256}) {
+                    test_cases.emplace_back(new test_moe_sum(type, hidden_dim, n_expert_used, n_tokens));
+                }
+            }
+        }
+    }
+
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_Q4_0}) {
         test_cases.emplace_back(new test_get_rows(type, 300*256, 5, 4, 1, 2, false));
         test_cases.emplace_back(new test_get_rows(type, 256, 80000, 70000, 2, 1, false));
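For context on what the op is meant to fuse: without GGML_OP_MOE_SUM, the same reduction is typically built as one 2D view plus one ggml_add per expert, along these lines (a rough sketch of the unfused pattern, not code taken verbatim from this patch or from llama.cpp):

    // experts_out: [hidden_dim, n_expert_used, n_tokens], contiguous
    struct ggml_tensor * moe_out = NULL;
    for (int e = 0; e < n_expert_used; ++e) {
        // 2D slice for expert e over all tokens: [hidden_dim, n_tokens]
        struct ggml_tensor * slice = ggml_view_2d(ctx, experts_out,
                experts_out->ne[0], experts_out->ne[2],
                experts_out->nb[2],             // row stride: one token apart
                e*experts_out->nb[1]);          // byte offset of expert e
        moe_out = moe_out == NULL ? ggml_cont(ctx, slice) : ggml_add(ctx, moe_out, slice);
    }
    // moe_out should match ggml_moe_sum(ctx, experts_out, n_expert_used)

ggml_moe_sum collapses this chain of n_expert_used adds into a single pass over contiguous memory, which is what the CPU loops and the CUDA kernels above exploit.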