feat: add --moe-n-expert flag for MoE expert count override (Hard Mask)

Add the ability to reduce the number of active experts in MoE models at
runtime. Running with 50% of the model's default expert count gives a
significant speedup with minimal quality loss.

Implementation:
- Add moe_n_expert_override parameter to llama_context_params
- Add --moe-n-expert CLI flag to override n_expert_used
- Implement "Hard Mask" in build_moe_ffn() that slices expert tensors
- Use ggml_view_2d/3d + ggml_cont so only the retained experts are
  actually computed (see the sketch below)
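
The slicing step, distilled into a standalone helper for illustration.
The helper name and signature are mine, not part of the patch; the real
change lives inside build_moe_ffn in the diff below.

// Illustrative sketch only: hard_mask_slice() mirrors the ggml calls
// used inside build_moe_ffn below.
#include "ggml.h"

// selected_experts: [n_expert_used, n_tokens] (I32 expert ids per token)
// weights:          [1, n_expert_used, n_tokens] (F32 routing weights)
static void hard_mask_slice(ggml_context * ctx0,
                            ggml_tensor * & selected_experts,
                            ggml_tensor * & weights,
                            int64_t n_exec, int64_t n_tokens) {
    // keep only the first n_exec expert ids per token; ggml_mul_mat_id then
    // loads and computes just those experts instead of all n_expert_used
    selected_experts = ggml_view_2d(ctx0, selected_experts, n_exec, n_tokens,
                                    selected_experts->nb[1], 0);
    selected_experts = ggml_cont(ctx0, selected_experts); // contiguous for later ops

    // matching slice of the routing weights
    weights = ggml_view_3d(ctx0, weights, 1, n_exec, n_tokens,
                           weights->nb[1], weights->nb[2], 0);
    weights = ggml_cont(ctx0, weights); // contiguous for the reshapes that follow
}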

Benchmark results (AOCL BLIS 5.0, AMD EPYC 9655):
- Qwen3-Coder-480B-A35B: 2.5 → 3.7 t/s (48% speedup)
- GLM-4.6-355B-A32B: 2.2 → 3.0 t/s (36% speedup)
- Qwen3-Coder-30B-A3B: 26.6 → 33.6 t/s (26% speedup)
- Qwen3-VL-30B-A3B: 32.2 → 38.9 t/s (21% speedup)

Quality: Excellent at 50% experts, degraded at 25%, gibberish at 12.5%

Usage: llama-cli -m model.gguf --moe-n-expert 4 -p "prompt"
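
The flag can also be set via the LLAMA_ARG_MOE_N_EXPERT environment
variable it registers, or programmatically through the new
llama_context_params field. A minimal C++ sketch of the programmatic
path (stock llama.h loader calls; error handling and decoding omitted):

// Sketch of setting the override from code; the field is added to
// llama_context_params in the llama.h hunk below.
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    cparams.moe_n_expert_override = 4; // run only 4 active experts (0 = model default)
    // for MoE self-draft speculation, a second "draft" context would use 1 here

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... tokenize, llama_decode(), sample as usual ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}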

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: pestopoppa
Date:   2025-12-14 13:32:50 +01:00
Commit: 49162df87a (parent 7bed317f53)
7 changed files with 63 additions and 15 deletions

@@ -1687,6 +1687,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.yarn_beta_fast = std::stof(value);
         }
     ).set_env("LLAMA_ARG_YARN_BETA_FAST"));
+    add_opt(common_arg(
+        {"--moe-n-expert"}, "N",
+        string_format("MoE: override number of active experts (default: %d = model default)\n"
+            "for MoE self-draft speculation, use 1 for draft context", params.moe_n_expert_override),
+        [](common_params & params, int value) {
+            params.moe_n_expert_override = value;
+        }
+    ).set_env("LLAMA_ARG_MOE_N_EXPERT"));
     add_opt(common_arg(
         {"-gan", "--grp-attn-n"}, "N",
         string_format("group-attention factor (default: %d)", params.grp_attn_n),

@@ -1328,6 +1328,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+    cparams.moe_n_expert_override = params.moe_n_expert_override;
     cparams.pooling_type = params.pooling_type;
     cparams.attention_type = params.attention_type;
     cparams.flash_attn_type = params.flash_attn_type;

@@ -321,6 +321,7 @@ struct common_params {
     float yarn_beta_fast = -1.0f; // YaRN low correction dim
     float yarn_beta_slow = -1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx = 0; // YaRN original context length
+    int32_t moe_n_expert_override = 0; // MoE self-draft: override n_expert_used (0 = use model default)
     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading

@@ -340,6 +340,11 @@ extern "C" {
         uint32_t yarn_orig_ctx; // YaRN original context size
         float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default)
+        // MoE self-drafting: override n_expert_used for this context
+        // 0 = use model default, 1+ = force exactly N active experts
+        // Used for MoE self-draft speculation: draft context uses n=1, verify uses full
+        int32_t moe_n_expert_override;
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;

@@ -97,6 +97,7 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
+    cparams.moe_n_expert_override = params.moe_n_expert_override;
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
@@ -2303,6 +2304,7 @@ llama_context_params llama_context_default_params() {
         /*.yarn_beta_slow =*/ -1.0f,
         /*.yarn_orig_ctx =*/ 0,
         /*.defrag_thold =*/ -1.0f,
+        /*.moe_n_expert_override =*/ 0,
         /*.cb_eval =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,

@@ -35,6 +35,10 @@ struct llama_cparams {
     bool op_offload;
     bool kv_unified;
+    // MoE self-drafting: override n_expert_used
+    // 0 = use model default, 1+ = force exactly N active experts
+    int32_t moe_n_expert_override;
     enum llama_pooling_type pooling_type;
     ggml_backend_sched_eval_callback cb_eval;

@@ -995,16 +995,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
+    // HARD MASK: If moe_n_expert_override is set, slice tensors to only use first N experts
+    // This actually reduces computation by only loading/computing N experts instead of all n_expert_used
+    // Unlike soft mask (which zeros weights but still computes all experts), hard mask skips the computation entirely
+    int32_t n_expert_exec = n_expert_used; // Default: execute all selected experts
+    if (cparams.moe_n_expert_override > 0 && cparams.moe_n_expert_override < n_expert_used) {
+        n_expert_exec = cparams.moe_n_expert_override;
+        // Slice selected_experts from [n_expert_used, n_tokens] to [n_expert_exec, n_tokens]
+        // This causes ggml_mul_mat_id to only load and compute the first n_expert_exec experts
+        selected_experts = ggml_view_2d(ctx0, selected_experts, n_expert_exec, n_tokens,
+                selected_experts->nb[1], 0);
+        // Make contiguous for subsequent operations
+        selected_experts = ggml_cont(ctx0, selected_experts);
+        cb(selected_experts, "ffn_moe_topk_sliced", il);
+        // Slice weights from [1, n_expert_used, n_tokens] to [1, n_expert_exec, n_tokens]
+        weights = ggml_view_3d(ctx0, weights, 1, n_expert_exec, n_tokens,
+                weights->nb[1], weights->nb[2], 0);
+        // Make contiguous for subsequent reshape operations
+        weights = ggml_cont(ctx0, weights);
+        cb(weights, "ffn_moe_weights_sliced", il);
+    }
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
-        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
-        weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_exec, n_tokens);
+        weights = ggml_soft_max(ctx0, weights); // [n_expert_exec, n_tokens]
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_exec, n_tokens);
         cb(weights, "ffn_moe_weights_softmax", il);
     }
     if (norm_w) {
-        weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+        weights = ggml_reshape_2d(ctx0, weights, n_expert_exec, n_tokens);
         ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
         cb(weights_sum, "ffn_moe_weights_sum", il);
@@ -1013,10 +1035,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
         cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
-        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_exec, n_tokens]
         cb(weights, "ffn_moe_weights_norm", il);
-        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+        weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_exec, n_tokens);
     }
     if (scale_w) {
         weights = ggml_scale(ctx0, weights, w_scale);
@@ -1029,8 +1051,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
     if (weight_before_ffn) {
-        // repeat cur to [n_embd, n_expert_used, n_tokens]
-        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
+        // repeat cur to [n_embd, n_expert_exec, n_tokens]
+        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_exec, n_tokens, 1);
         cur = ggml_mul(ctx0, repeated, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
@@ -1108,26 +1130,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
-    assert(n_expert_used > 0);
+    // Determine actual expert count for aggregation
+    // When --moe-n-expert is set (hard mask mode), use n_expert_exec
+    // Otherwise use hparams.n_expert_used to avoid dynamic allocation issues during warmup
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    const uint32_t n_expert_agg = (cparams.moe_n_expert_override > 0)
+        ? (uint32_t)n_expert_exec
+        : hparams.n_expert_used;
+    assert(n_expert_agg > 0);
     // order the views before the adds
-    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+    for (uint32_t i = 0; i < n_expert_agg; ++i) {
         cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
         ggml_build_forward_expand(gf, cur_experts[i]);
     }
     // aggregate experts
-    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
-    // to avoid potentially a large number of add nodes during warmup
-    // ref: https://github.com/ggml-org/llama.cpp/pull/14753
     ggml_tensor * moe_out = cur_experts[0];
-    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+    for (uint32_t i = 1; i < n_expert_agg; ++i) {
         moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
     }
-    if (hparams.n_expert_used == 1) {
+    if (n_expert_agg == 1) {
         // avoid returning a non-contiguous tensor
         moe_out = ggml_cont(ctx0, moe_out);
     }