diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 6e6e618989..846f894e03 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -116,7 +116,8 @@ class ModelBase:
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
                  disable_mistral_community_chat_template: bool = False,
-                 sentence_transformers_dense_modules: bool = False):
+                 sentence_transformers_dense_modules: bool = False,
+                 fuse_gate_up_exps: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
+        self.fuse_gate_up_exps = fuse_gate_up_exps
+        self._gate_exp_buffer: dict[int, Tensor] = {}
+        self._up_exp_buffer: dict[int, Tensor] = {}
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
         self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
@@ -514,8 +518,36 @@ class ModelBase:
         raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-        return [(self.map_tensor_name(name), data_torch)]
+        new_name = self.map_tensor_name(name)
+
+        # Handle gate/up expert tensor fusion if enabled
+        if self.fuse_gate_up_exps and bid is not None:
+            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
+                self._gate_exp_buffer[bid] = data_torch
+                # Check if up_exps is already buffered for this layer
+                if bid in self._up_exp_buffer:
+                    gate_data = self._gate_exp_buffer.pop(bid)
+                    up_data = self._up_exp_buffer.pop(bid)
+                    # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
+                    fused_data = torch.cat([gate_data, up_data], dim=1)
+                    fused_name = f"blk.{bid}.ffn_gate_up_exps.weight"
+                    logger.info(f"Fused gate_exps and up_exps for layer {bid}")
+                    return [(fused_name, fused_data)]
+                return []  # Wait for up_exps
+            elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
+                self._up_exp_buffer[bid] = data_torch
+                # Check if gate_exps is already buffered for this layer
+                if bid in self._gate_exp_buffer:
+                    gate_data = self._gate_exp_buffer.pop(bid)
+                    up_data = self._up_exp_buffer.pop(bid)
+                    # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
+                    fused_data = torch.cat([gate_data, up_data], dim=1)
+                    fused_name = f"blk.{bid}.ffn_gate_up_exps.weight"
+                    logger.info(f"Fused gate_exps and up_exps for layer {bid}")
+                    return [(fused_name, fused_data)]
+                return []  # Wait for gate_exps
+
+        return [(new_name, data_torch)]
 
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
@@ -11121,6 +11153,11 @@ def parse_args() -> argparse.Namespace:
               "Default these modules are not included.")
     )
 
+    parser.add_argument(
+        "--fuse-gate-up-exps", action="store_true",
+        help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
+    )
+
    args = parser.parse_args()
    if not args.print_supported_models and args.model is None:
        parser.error("the following arguments are required: model")
@@ -11258,7 +11295,8 @@ def main() -> None:
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split, remote_hf_model_id=hf_repo_id,
             disable_mistral_community_chat_template=disable_mistral_community_chat_template,
-            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
+            fuse_gate_up_exps=args.fuse_gate_up_exps
         )
 
     if args.vocab_only:
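Note on the conversion logic above: tensors stream through modify_tensors() in arbitrary order, so the converter buffers whichever of gate_exps/up_exps arrives first for a layer and emits a single fused tensor once its partner shows up. A minimal standalone sketch of the concatenation and its round-trip, using made-up toy shapes rather than any real checkpoint:

    import torch

    # Toy dimensions (hypothetical): 4 experts, FFN width 8, embedding width 6.
    n_expert, n_ff, n_embd = 4, 8, 6
    gate = torch.randn(n_expert, n_ff, n_embd)
    up   = torch.randn(n_expert, n_ff, n_embd)

    # Same operation as the converter: concatenate along dim 1 -> (n_expert, n_ff*2, n_embd).
    fused = torch.cat([gate, up], dim=1)
    assert fused.shape == (n_expert, 2 * n_ff, n_embd)

    # The two halves round-trip back to the originals, which is what the
    # runtime's view-based split relies on.
    assert torch.equal(fused[:, :n_ff, :], gate)
    assert torch.equal(fused[:, n_ff:, :], up)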
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 31273b2b5a..ef0437e74a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -511,6 +511,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_EXP         = auto()
     FFN_DOWN_EXP         = auto()
     FFN_UP_EXP           = auto()
+    FFN_GATE_UP_EXP      = auto()
     FFN_GATE_SHEXP       = auto()
     FFN_DOWN_SHEXP       = auto()
     FFN_UP_SHEXP         = auto()
@@ -937,6 +938,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE_EXP:              "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:              "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:                "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_GATE_UP_EXP:           "blk.{bid}.ffn_gate_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B:           "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:            "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.PER_LAYER_TOKEN_EMBD:      "per_layer_token_embd",  # gemma3n
@@ -3115,6 +3117,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_SINKS,
         MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 84aa868809..f1801c7f74 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -555,6 +555,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.chunk_experts.gate_proj",  # grovemoe
         ),
 
+        MODEL_TENSOR.FFN_GATE_UP_EXP: (
+            "model.layers.{bid}.mlp.experts.gate_up_proj",  # gpt-oss
+        ),
+
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",  # gptneox
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index a54bc1956a..b2bdf3e250 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -335,6 +335,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_FFN_GATE_EXPS,             "blk.%d.ffn_gate_exps" },
     { LLM_TENSOR_FFN_DOWN_EXPS,             "blk.%d.ffn_down_exps" },
     { LLM_TENSOR_FFN_UP_EXPS,               "blk.%d.ffn_up_exps" },
+    { LLM_TENSOR_FFN_GATE_UP_EXPS,          "blk.%d.ffn_gate_up_exps" },
     { LLM_TENSOR_ATTN_POST_NORM,            "blk.%d.post_attention_norm" },
     { LLM_TENSOR_ATTN_Q_NORM,               "blk.%d.attn_q_norm" },
     { LLM_TENSOR_ATTN_K_NORM,               "blk.%d.attn_k_norm" },
@@ -1497,6 +1498,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
         LLM_TENSOR_FFN_GATE_INP,
         LLM_TENSOR_FFN_GATE_EXPS,
         LLM_TENSOR_FFN_DOWN_EXPS,
+        LLM_TENSOR_FFN_GATE_UP_EXPS,
         LLM_TENSOR_FFN_UP_EXPS,
         LLM_TENSOR_FFN_GATE_INP_SHEXP,
         LLM_TENSOR_FFN_GATE_SHEXP,
@@ -2088,6 +2090,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
         LLM_TENSOR_ATTN_OUT,
         LLM_TENSOR_ATTN_SINKS,
         LLM_TENSOR_FFN_GATE_INP,
+        LLM_TENSOR_FFN_GATE_UP_EXPS,
         LLM_TENSOR_FFN_GATE_EXPS,
         LLM_TENSOR_FFN_DOWN_EXPS,
         LLM_TENSOR_FFN_UP_EXPS,
@@ -2434,6 +2437,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_UP_EXPS,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_DOWN_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_CHEXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 270d28b16a..91af4d29c9 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -357,6 +357,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
     LLM_TENSOR_FFN_GATE_EXPS,
     LLM_TENSOR_FFN_UP_EXPS,
+    LLM_TENSOR_FFN_GATE_UP_EXPS,
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b3198b7e3a..6450bac26f 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1070,7 +1070,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il,
-        ggml_tensor * probs_in) const {
+        ggml_tensor * probs_in,
+        ggml_tensor * gate_up_exps) const {
     return build_moe_ffn(
             cur,
             gate_inp,        /* gate_inp_b   */ nullptr,
@@ -1086,7 +1087,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             w_scale,
             gating_op,
             il,
-            probs_in
+            probs_in,
+            gate_up_exps
         );
 }
 
@@ -1109,7 +1111,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il,
-        ggml_tensor * probs_in) const {
+        ggml_tensor * probs_in,
+        ggml_tensor * gate_up_exps,
+        ggml_tensor * gate_up_exps_b) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool    weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1248,30 +1252,52 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    if (up_exps_b) {
-        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
-        cb(up, "ffn_moe_up_biased", il);
-    }
-
+    ggml_tensor * up      = nullptr;
     ggml_tensor * experts = nullptr;
-    if (gate_exps) {
-        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+
+    if (gate_up_exps) {
+        // merged gate_up path: one mul_mat_id, then split into gate and up views
+        ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
+        cb(gate_up, "ffn_moe_gate_up", il);
+
+        if (gate_up_exps_b) {
+            gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
+            cb(gate_up, "ffn_moe_gate_up_biased", il);
+        }
+
+        const int64_t n_ff = gate_up->ne[0] / 2;
+        cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
         cb(cur, "ffn_moe_gate", il);
+        up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
+        cb(up, "ffn_moe_up", il);
     } else {
-        cur = up;
+        // separate gate and up path
+        up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(up, "ffn_moe_up", il);
+
+        if (up_exps_b) {
+            up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+            cb(up, "ffn_moe_up_biased", il);
+        }
+
+        if (gate_exps) {
+            cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+            cb(cur, "ffn_moe_gate", il);
+        } else {
+            cur = up;
+        }
+
+        if (gate_exps_b) {
+            cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+            cb(cur, "ffn_moe_gate_biased", il);
+        }
     }
 
-    if (gate_exps_b) {
-        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
-        cb(cur, "ffn_moe_gate_biased", il);
-    }
+    const bool has_gate = gate_exps || gate_up_exps;
 
     switch (type_op) {
         case LLM_FFN_SILU:
-            if (gate_exps) {
+            if (has_gate) {
                 cur = ggml_swiglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_swiglu", il);
             } else {
@@ -1279,7 +1305,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            if (gate_exps) {
+            if (has_gate) {
                 cur = ggml_geglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_geglu", il);
             } else {
@@ -1295,7 +1321,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_swiglu_oai", il);
             } break;
         case LLM_FFN_RELU:
-            if (gate_exps) {
+            if (has_gate) {
                 cur = ggml_reglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_reglu", il);
             } else {
@@ -1303,7 +1329,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_relu", il);
             } break;
         case LLM_FFN_RELU_SQR:
-            if (gate_exps) {
+            if (has_gate) {
                 // TODO: add support for gated squared relu
                 GGML_ABORT("fatal error: gated squared relu not implemented");
             } else {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4090d8116c..160de313eb 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -786,7 +786,8 @@ struct llm_graph_context {
             float w_scale,
             llama_expert_gating_func_type gating_op,
             int il,
-            ggml_tensor * probs_in = nullptr) const;
+            ggml_tensor * probs_in     = nullptr,
+            ggml_tensor * gate_up_exps = nullptr) const;
 
     ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
@@ -807,7 +808,9 @@ struct llm_graph_context {
             float w_scale,
             llama_expert_gating_func_type gating_op,
             int il,
-            ggml_tensor * probs_in = nullptr) const;
+            ggml_tensor * probs_in       = nullptr,
+            ggml_tensor * gate_up_exps   = nullptr,
+            ggml_tensor * gate_up_exps_b = nullptr) const;
 
     //
     // inputs
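Note on the merged path in build_moe_ffn() above: a single mul_mat_id over the fused weights yields a [n_ff*2, n_expert_used, n_tokens] activation, and two ggml_view_3d calls then slice it into the gate half (byte offset 0) and the up half (byte offset n_ff * nb[0]) without any copy. A loose single-expert Python analogue of why this equals two separate matmuls, with hypothetical sizes and plain torch standing in for ggml:

    import torch
    import torch.nn.functional as F

    n_embd, n_ff, n_tokens = 6, 8, 3
    w_gate = torch.randn(n_ff, n_embd)
    w_up   = torch.randn(n_ff, n_embd)
    w_gate_up = torch.cat([w_gate, w_up], dim=0)   # gate rows first, then up rows

    x = torch.randn(n_embd, n_tokens)

    fused = w_gate_up @ x                          # one matmul: [n_ff*2, n_tokens]
    gate, up = fused[:n_ff], fused[n_ff:]          # the two "views" into the result

    # Identical (up to float error) to running the two matmuls separately.
    assert torch.allclose(gate, w_gate @ x, atol=1e-6)
    assert torch.allclose(up,   w_up   @ x, atol=1e-6)

    # SwiGLU then combines the halves, as ggml_swiglu_split does on the views.
    out = F.silu(gate) * up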
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index cc784e1cb0..d6cfd8f635 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4983,9 +4983,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
 
                     // MoE branch
-                    layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                     layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                    layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                    // try merged gate_up first, fall back to separate gate and up
+                    layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED);
+                    if (layer.ffn_gate_up_exps == nullptr) {
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                    }
 
                     // Shared expert branch
                     layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
@@ -6527,9 +6532,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
 
                     layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {  n_embd, n_expert}, 0);
-                    layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                     layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                    layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                    // try merged gate_up first, fall back to separate gate and up
+                    layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED);
+                    if (layer.ffn_gate_up_exps == nullptr) {
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+                    }
 
                     // bias
                     layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
@@ -6538,9 +6548,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
 
                     layer.ffn_gate_inp_b  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
-                    layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
                     layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), {  n_embd, n_expert}, 0);
-                    layer.ffn_up_exps_b   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "bias", i), {n_ff_exp, n_expert}, 0);
+
+                    // try merged gate_up bias first, fall back to separate gate and up
+                    layer.ffn_gate_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "bias", i), {n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED);
+                    if (layer.ffn_gate_up_exps_b == nullptr) {
+                        layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+                        layer.ffn_up_exps_b   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "bias", i), {n_ff_exp, n_expert}, 0);
+                    }
                 }
             } break;
         case LLM_ARCH_LFM2:
diff --git a/src/llama-model.h b/src/llama-model.h
index d1de16e3f2..fad89146a4 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -275,14 +275,16 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_enc   = nullptr;
 
     // ff MoE
-    struct ggml_tensor * ffn_gate_inp    = nullptr;
-    struct ggml_tensor * ffn_gate_exps   = nullptr;
-    struct ggml_tensor * ffn_down_exps   = nullptr;
-    struct ggml_tensor * ffn_up_exps     = nullptr;
-    struct ggml_tensor * ffn_gate_inp_b  = nullptr;
-    struct ggml_tensor * ffn_gate_exps_b = nullptr;
-    struct ggml_tensor * ffn_down_exps_b = nullptr;
-    struct ggml_tensor * ffn_up_exps_b   = nullptr;
+    struct ggml_tensor * ffn_gate_inp       = nullptr;
+    struct ggml_tensor * ffn_gate_exps      = nullptr;
+    struct ggml_tensor * ffn_down_exps      = nullptr;
+    struct ggml_tensor * ffn_up_exps        = nullptr;
+    struct ggml_tensor * ffn_gate_up_exps   = nullptr;
+    struct ggml_tensor * ffn_gate_inp_b     = nullptr;
+    struct ggml_tensor * ffn_gate_exps_b    = nullptr;
+    struct ggml_tensor * ffn_down_exps_b    = nullptr;
+    struct ggml_tensor * ffn_up_exps_b      = nullptr;
+    struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
 
     // ff shared expert (shexp)
     struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
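Note on the loader changes above: create_tensor() is first asked for the merged ffn_gate_up_exps with TENSOR_NOT_REQUIRED, and only when that probe returns null are the separate gate/up tensors required, so both fused and unfused GGUF files keep loading. To see which layout a converted file actually contains, one could list the expert FFN tensors with the gguf Python package (a sketch; the file name is a placeholder):

    from gguf import GGUFReader

    reader = GGUFReader("model-fused.gguf")  # hypothetical converter output
    for t in reader.tensors:
        if "ffn_gate_up_exps" in t.name or "ffn_gate_exps" in t.name or "ffn_up_exps" in t.name:
            print(t.name, t.shape)  # fused files show one ffn_gate_up_exps per layer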
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index 297dca5136..0bee3eb303 100644
--- a/src/models/deepseek2.cpp
+++ b/src/models/deepseek2.cpp
@@ -217,7 +217,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
                             LLM_FFN_SILU, hparams.expert_weights_norm,
                             true, hparams.expert_weights_scale,
                             (llama_expert_gating_func_type) hparams.expert_gating_func,
-                            il);
+                            il,
+                            nullptr,
+                            model.layers[il].ffn_gate_up_exps);
                     cb(moe_out, "ffn_moe_out", il);
 
                     // FFN shared expert
diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp
index dbe3ca1851..8f0096a30d 100644
--- a/src/models/openai-moe-iswa.cpp
+++ b/src/models/openai-moe-iswa.cpp
@@ -88,16 +88,18 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
 
         // MoE branch
         cur = build_moe_ffn(cur,
-                model.layers[il].ffn_gate_inp,  model.layers[il].ffn_gate_inp_b,
-                model.layers[il].ffn_up_exps,   model.layers[il].ffn_up_exps_b,
-                model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
-                model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
+                model.layers[il].ffn_gate_inp,   model.layers[il].ffn_gate_inp_b,
+                model.layers[il].ffn_up_exps,    model.layers[il].ffn_up_exps_b,
+                model.layers[il].ffn_gate_exps,  model.layers[il].ffn_gate_exps_b,
+                model.layers[il].ffn_down_exps,  model.layers[il].ffn_down_exps_b,
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SWIGLU_OAI_MOE, false,
                 false, 0.0,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
-                il);
+                il,
+                nullptr, // probs_in
+                model.layers[il].ffn_gate_up_exps, model.layers[il].ffn_gate_up_exps_b);
         cb(cur, "ffn_moe_out", il);
 
         cur = ggml_add(ctx0, cur, ffn_inp);
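Usage note: with this patch applied, fusion is opt-in at conversion time via the new flag, for example (paths are placeholders):

    python convert_hf_to_gguf.py --fuse-gate-up-exps /path/to/model --outfile model-fused.gguf

At load time no flag is needed: llama_model::load_tensors() detects the merged tensor automatically, and build_moe_ffn() then takes the single-matmul path shown earlier.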