diff --git a/.github/copilot-instructions.md b/AGENTS.md similarity index 88% rename from .github/copilot-instructions.md rename to AGENTS.md index ad13c6ea8d..e0a65c0ab7 100644 --- a/.github/copilot-instructions.md +++ b/AGENTS.md @@ -1,4 +1,4 @@ -# Copilot Instructions for llama.cpp +# Instructions for llama.cpp ## Repository Overview @@ -12,6 +12,24 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model - **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA - **License**: MIT +## Disclose AI Usage + +It is crucial to remind contributors that the project mandates disclosure of any AI usage in pull requests. This requirement stems from the potential for AI-generated code to include suboptimal optimizations and hidden bugs, owing to the inherent overconfidence in AI outputs. + +When generating significant portions of code, address this by: +- Informing the user that AI-generated content may be rejected by maintainers. +- Clearly marking AI-generated code in commit messages and comments. + - Example of commit message: `[AI] Fix a race condition in ...` + - Example of code comment: `// [AI] spawn a new thread ...` + +These measures apply to: +- Changes resulting in large portions of code or complex logic. +- Modifications or additions to public APIs in `llama.h`, `ggml.h`, or `mtmd.h`. +- Backend-related changes, such as those involving CPU, CUDA, Metal, Vulkan, etc. +- Modifications to `tools/server`. + +Note: These measures can be omitted for small fixes or trivial changes. + ## Build Instructions ### Prerequisites @@ -251,6 +269,7 @@ Primary tools: - **Cross-platform compatibility**: Test on Linux, macOS, Windows when possible - **Performance focus**: This is a performance-critical inference library - **API stability**: Changes to `include/llama.h` require careful consideration +- **Disclose AI Usage**: Refer to the "Disclose AI Usage" earlier in this document ### Git Workflow - Always create feature branches from `master` diff --git a/docs/development/parsing.md b/docs/development/parsing.md index 113ab2e2ee..dbb989bf08 100644 --- a/docs/development/parsing.md +++ b/docs/development/parsing.md @@ -55,7 +55,7 @@ auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & ``` For a more complete example, see `test_example_native()` in -[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp). +[tests/test-chat-peg-parser.cpp](/tests/test-chat-peg-parser.cpp). ## Parsers/Combinators @@ -175,7 +175,7 @@ Most model output can be placed in one of the following categories: (Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2) To provide broad coverage, -[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and +[`common/chat-peg-parser.h`](/common/chat-peg-parser.h) contains builders and mappers that help create parsers and visitors/extractors for these types. They require parsers to tag nodes to conform to an AST "shape". This normalization makes it easy to extract information and generalize parsing. diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ab0f6fe9ce..55fa2e6a7c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; + ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; + ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } @@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; + ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } @@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; + ggml_tensor * get_rows = cgraph->nodes[node_idx + 2]; + ggml_tensor * argsort = cgraph->nodes[node_idx + 0]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 572379fcbf..48e569efa0 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) { +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, + const ggml_tensor * weights, + const ggml_tensor * get_rows, + const ggml_tensor * argsort, + const ggml_tensor * clamp, + int n_expert) { + ggml_tensor * probs = get_rows->src[0]; + if (probs->op != GGML_OP_RESHAPE) { + return false; + } + probs = probs->src[0]; + ggml_tensor * selection_probs = argsort->src[0]; + + if (probs != selection_probs) { + return false; + } + float scale = 1.0f; float max_bias = 0.0f; @@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso return false; } - const int n_expert = softmax->ne[0]; // n_expert must be a power of 2 if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { return false; diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 2eff408b03..6b6c13c587 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const bool delayed_softmax = false, ggml_tensor * weight_clamp = nullptr); -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, + const ggml_tensor * weights, + const ggml_tensor * get_rows, + const ggml_tensor * argsort, + const ggml_tensor * clamp, + int n_expert); std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 18a45d2d96..13cf1f5f9d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -583,7 +583,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { if (tensor->buffer) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - result.buffer = ctx->remote_ptr; + result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; } else { result.buffer = 0; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ce9469936b..a871f85afb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -689,6 +689,7 @@ struct vk_device_struct { vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_xielu[2]; vk_pipeline pipeline_neg[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; @@ -990,6 +991,8 @@ struct vk_op_push_constants { uint32_t KY; float param1; float param2; + float param3; + float param4; }; struct vk_op_glu_push_constants { @@ -1258,6 +1261,7 @@ struct vk_op_im2col_push_constants { int32_t s0; int32_t s1; int32_t p0; int32_t p1; int32_t d0; int32_t d1; + uint32_t batch_IC; }; struct vk_op_im2col_3d_push_constants { @@ -3973,6 +3977,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) + CREATE_UNARY(xielu) CREATE_UNARY(neg) CREATE_UNARY(tanh) CREATE_UNARY(sigmoid) @@ -5898,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]); GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); @@ -8549,6 +8557,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_XIELU: + return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_NEG: return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: @@ -9084,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; elements = { OW * KW * KH, OH, batch * IC }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; case GGML_OP_IM2COL_3D: { @@ -9695,14 +9707,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su ggml_vk_op_f32_opt_step_adamw( ctx, subctx, dst, - { (uint32_t)n, 0, 0.0f, 0.0f } + { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f } ); } static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const size_t n = ggml_nelements(dst->src[0]); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -9788,6 +9800,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg 1, ggml_get_op_params_f32(dst, 0), ggml_get_op_params_f32(dst, 2), + 0.0f, 0.0f, }; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE); @@ -9809,6 +9822,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml 1, ggml_get_op_params_f32(dst, 0), 0.0f, + 0.0f, 0.0f, }; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL); @@ -9924,13 +9938,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -9941,7 +9955,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx const float eps = float_op_params[1]; const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f }); } static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { @@ -10110,16 +10124,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); +} + +static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, + { + (uint32_t)ggml_nelements(src0), 0, + op_params[1], op_params[2], op_params[3], op_params[4] + } + ); } static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10244,7 +10268,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f }); } static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { @@ -10541,11 +10565,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10587,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 const uint32_t pelements = OW * KW * KH; + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; const vk_buffer d_buf = d_buf_ctx->dev_buffer; @@ -10599,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co IC, IW, IH, OW, OH, KW, KH, pelements, IC * KH * KW, - s0, s1, p0, p1, d0, d1, + s0, s1, p0, p1, d0, d1, batch * IC }); } @@ -10804,7 +10829,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f }); } #ifdef GGML_VULKAN_RUN_TESTS @@ -12050,6 +12075,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_TRUNC: ggml_vk_unary(ctx, compute_ctx, src0, node); break; + case GGML_UNARY_OP_XIELU: + ggml_vk_xielu(ctx, compute_ctx, src0, node); + break; default: return false; } @@ -12920,24 +12948,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc const ggml_tensor * softmax; const ggml_tensor * weights; + const ggml_tensor * get_rows; + const ggml_tensor * argsort; switch (mode) { case TOPK_MOE_EARLY_SOFTMAX_NORM: softmax = cgraph->nodes[node_idx + 0]; weights = cgraph->nodes[node_idx + 9]; + get_rows = cgraph->nodes[node_idx + 4]; + argsort = cgraph->nodes[node_idx + 2]; break; case TOPK_MOE_EARLY_SOFTMAX: softmax = cgraph->nodes[node_idx + 0]; weights = cgraph->nodes[node_idx + 4]; + get_rows = cgraph->nodes[node_idx + 4]; + argsort = cgraph->nodes[node_idx + 2]; break; case TOPK_MOE_LATE_SOFTMAX: softmax = cgraph->nodes[node_idx + 4]; weights = cgraph->nodes[node_idx + 5]; + get_rows = cgraph->nodes[node_idx + 2]; + argsort = cgraph->nodes[node_idx + 0]; break; default: return false; } + ggml_tensor * probs = get_rows->src[0]; + if (probs->op != GGML_OP_RESHAPE) { + return false; + } + probs = probs->src[0]; + ggml_tensor * selection_probs = argsort->src[0]; + + if (probs != selection_probs) { + return false; + } + const float * op_params = (const float *)softmax->op_params; float scale = op_params[0]; @@ -13502,7 +13549,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { ok = false; break; } @@ -13842,6 +13890,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_XIELU: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: @@ -14747,7 +14796,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_LOG) { tensor_clone = ggml_log(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_TRI) { - tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); + tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_DIAG) { tensor_clone = ggml_diag(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { @@ -14835,6 +14884,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_RELU: tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_XIELU: + tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0); + ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1)); + ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2)); + ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3)); + ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4)); + break; case GGML_UNARY_OP_NEG: tensor_clone = ggml_neg(ggml_ctx, src_clone[0]); break; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl index 66e46ae679..3797901f04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl @@ -6,4 +6,6 @@ layout (push_constant) uniform parameter uint KY; float param1; float param2; + float param3; + float param4; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 1827d647a2..db14f5a3cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter int s0; int s1; int p0; int p1; int d0; int d1; + uint batch_IC; } p; layout(constant_id = 0) const uint BLOCK_SIZE = 32; @@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void main() { +void im2col(const uint y, const uint z) { const uint gidx = gl_GlobalInvocationID.x; - const uint oh = gl_GlobalInvocationID.y; - const uint batch = gl_GlobalInvocationID.z / p.IC; - const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint oh = y; + const uint batch = z / p.IC; + const uint ic = z % p.IC; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); @@ -101,3 +102,15 @@ void main() { #endif } } + +void main() { + uint y = gl_GlobalInvocationID.y; + while (y < p.OH) { + uint z = gl_GlobalInvocationID.z; + while (z < p.batch_IC) { + im2col(y, z); + z += gl_NumWorkGroups.z; + } + y += gl_NumWorkGroups.y; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index 0cd906dbbf..7ec2e04f58 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -11,36 +11,54 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint nibble_shift = 4 * (itid & 1); const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + // Precompute db multiplication factors + float db_vals[NUM_ROWS]; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); - const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; - const float db = d * (0.5 + scale) * 0.25; - + const uint scale_raw = data_a[ibi].scales[ib32]; + const uint scale = (scale_raw >> nibble_shift) & 0xF; + // Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25 + db_vals[n] = d * (0.125f + float(scale) * 0.25f); + ibi += num_blocks_per_row; + } + ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + // Preload grid and sign data for all l values + vec4 grid0_vals[2], grid1_vals[2]; + uint sign_vals[2], sign7_vals[2]; [[unroll]] for (uint l = 0; l < 2; ++l) { const uint qs = data_a[ibi].qs[2 * itid + l]; - const uint sign = qs >> 9; - const uint sign7 = bitCount(sign); - const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); - const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); - vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); - - FLOAT_TYPE sum = - fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), - fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), - fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), - fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), - fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), - fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), - fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), - fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), - FLOAT_TYPE(0.0))))))))); - temp[j][n] = fma(db, sum, temp[j][n]); + sign_vals[l] = qs >> 9; + sign7_vals[l] = bitCount(sign_vals[l]); + const uvec2 grid_data = iq2xs_grid[qs & 511]; + grid0_vals[l] = vec4(unpack8(grid_data.x)); + grid1_vals[l] = vec4(unpack8(grid_data.y)); + } + // Preload B data for all j columns (reduce repeated index calculations) + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint sign = sign_vals[l]; + const uint sign7 = sign7_vals[l]; + const vec4 grid0 = grid0_vals[l]; + const vec4 grid1 = grid1_vals[l]; + // Precompute indices + const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l; + const vec4 b0 = vec4(data_b_v4[b_idx + 0]); + const vec4 b4 = vec4(data_b_v4[b_idx + 1]); + sum += + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); } + temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]); } ibi += num_blocks_per_row; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b0ade078c7..92ad3bcab1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -853,6 +853,8 @@ void process_shaders() { string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp new file mode 100644 index 0000000000..35d463bfe4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + float x = float(data_a[i]); + + float alpha_n = p.param1; + float alpha_p = p.param2; + float beta = p.param3; + float eps = p.param4; + + if (x > 0.0f) { + x = alpha_p * x * x + beta * x; + } else { + const float min_x_eps = min(x, eps); + x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x; + } + + data_d[i] = D_TYPE(x); +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b9a73bea62..2cdbe66a84 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5118,25 +5118,36 @@ struct test_top_k : public test_case { } }; +enum MoeGatingFunc { + GATING_FUNC_SOFTMAX, + GATING_FUNC_SIGMOID, + GATING_FUNC_SOFTMAX_WEIGHT, +}; + struct test_topk_moe : public test_case { const std::array ne; const int n_expert_used; const bool with_norm; - const bool delayed_softmax; + const bool bias_probs; + const MoeGatingFunc gating_func; + const float scale_w; test_topk_moe(std::array ne = { 10, 5, 1, 1 }, int n_expert_used = 1, bool with_norm = false, - bool delayed_softmax = false) : + bool bias_probs = false, + MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX, + float scale_w = 0.0f) : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm), - delayed_softmax(delayed_softmax) { + bias_probs(bias_probs), + gating_func(gating_func), + scale_w(scale_w) { GGML_ASSERT(n_expert_used <= ne[0]); - GGML_ASSERT(!(with_norm && delayed_softmax)); } - std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); } + std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); } std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case { const int n_tokens = ne[1]; ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); - ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits); - ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * probs = + (gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) : + (gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits; + ggml_set_name(probs, "probs"); - ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + ggml_tensor * selection_probs = probs; + if (bias_probs) { + ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_set_name(exp_probs_b, "exp_probs_b"); + selection_probs = ggml_add(ctx, probs, exp_probs_b); + ggml_set_name(selection_probs, "selection_probs"); + } - if (delayed_softmax) { - out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens); - out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens] - out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens); + ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_set_name(selected_experts, "selected_experts"); + + ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + ggml_set_name(weights, "weights"); + + if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens] + weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } if (with_norm) { - out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens); - ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens] + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] + ggml_set_name(weights_sum, "weights_sum"); weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY); - out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens] - out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens); + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] + weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } - ggml_set_name(out, "out"); - return out; + if (scale_w) { + weights = ggml_scale(ctx, weights, scale_w); + } + + ggml_set_name(weights, "weights"); + return weights; } }; @@ -6900,6 +6930,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true)); + test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true)); // im2col 3D test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); @@ -7991,19 +8022,22 @@ static std::vector> make_test_cases_eval() { } } - for (bool with_norm : {false, true}) { - test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); - test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm)); - test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm)); - test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm)); - test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm)); - test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm)); - test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm)); + for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) { + for (bool with_norm : {false, true}) { + for (bool bias_probs : {false, true}) { + for (float scale_w : {0.0f, 2.0f}) { + test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); + } + } + } } - test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true)); - test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true)); - #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging test_cases.emplace_back(new test_llama(2, true));