From 079feab9e3efee1d6d4ca370eac50f156e2dc6e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Sat, 14 Feb 2026 22:22:32 +0100
Subject: [PATCH] convert : ensure all models handle new experts count
 (#19621)

* ensure all models handle new experts count

* revert removal for PhiMoeModel, does not inherit from base
---
 convert_hf_to_gguf.py | 76 +++++++++++++++-----------------------------
 1 file changed, 26 insertions(+), 50 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 7f341a58dd..0f614e4df3 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2726,8 +2726,6 @@ class AfmoeModel(LlamaModel):
         super().set_gguf_parameters()
 
         # MoE parameters
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
             self.gguf_writer.add_expert_shared_count(n_shared_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
@@ -2749,7 +2747,7 @@
         # Handle expert weights - they're already merged in the HF format
         # process the experts separately
         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -4197,8 +4195,6 @@ class Qwen2MoeModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
             logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -4243,7 +4239,7 @@
             return
 
         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -4994,13 +4990,13 @@ class PhiMoeModel(Phi3MiniModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
-        self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
+        self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
+        self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
-            n_experts = self.hparams["num_local_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -5412,7 +5408,7 @@ class KimiLinearModel(TextModel):
 
         # process the experts separately
         if name.find("block_sparse_moe.experts") != -1:
-            n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
            assert bid is not None
 
             if self._experts is None:
@@ -6007,12 +6003,13 @@ class NomicBertModel(BertModel):
         if "mlp.experts.bias" in name:
             return  # Explicitly return.
 
+        n_experts = self.find_hparam(["num_local_experts", "num_experts"])
         if "mlp.experts.mlp.w1" in name:
-            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
             name += ".weight"
 
         if "mlp.experts.mlp.w2" in name:
-            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
             data_torch = data_torch.transpose(1, 2)
             name += ".weight"
 
@@ -6022,7 +6019,6 @@
         super().set_gguf_parameters()
         if self.is_moe:
             self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
-            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
             self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
 
     def _is_tokenizer_xlmroberta(self) -> bool:
@@ -7259,8 +7255,8 @@ class JambaModel(TextModel):
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
-        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
+        self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
         self.gguf_writer.add_file_type(self.ftype)
 
     _experts: list[dict[str, Tensor]] | None = None
@@ -7278,7 +7274,7 @@
 
         # process the experts separately
         if ".feed_forward.experts." in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
 
             assert bid is not None
 
@@ -7426,8 +7422,6 @@ class OlmoeModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_layer_norm_rms_eps(1e-5)
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
 
     _experts: list[dict[str, Tensor]] | None = None
 
@@ -7435,7 +7429,7 @@
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -8016,10 +8010,6 @@ class MiniMaxM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MINIMAXM2
     _experts_cache: dict[int, dict[str, Tensor]] = {}
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hparams["num_experts"] = self.hparams["num_local_experts"]
-
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
@@ -8032,7 +8022,7 @@
 
         # merge expert weights
         if 'experts' in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             expert_cache = self._experts_cache.setdefault(bid, {})
@@ -9237,7 +9227,6 @@ class ExaoneMoEModel(Exaone4Model):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
         moe_intermediate_size = self.hparams["moe_intermediate_size"]
         num_shared_experts = self.hparams["num_shared_experts"]
         self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -9278,7 +9267,7 @@ class ExaoneMoEModel(Exaone4Model):
         name = name.replace("e_score_correction_bias", "e_score_correction.bias")
 
         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -9429,7 +9418,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
         # case, the model architecture needs to be updated to a standard
         # "granite" or "granitemoe" model
         if not self._ssm_layers:
-            has_experts = self.find_hparam(["num_experts_per_tok"], optional=True)
+            has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)
             new_arch = (
                 gguf.MODEL_ARCH.GRANITE_MOE
                 if has_experts else
@@ -9727,7 +9716,6 @@ class BailingMoeModel(TextModel):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_weights_scale(1.0)
-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
 
@@ -9761,7 +9749,7 @@
             yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid)
             return
         elif name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -9832,7 +9820,6 @@ class BailingMoeV2Model(TextModel):
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
 
@@ -9843,7 +9830,7 @@
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if "mlp.experts" in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -9889,8 +9876,6 @@ class GroveMoeModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
         if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
             logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
@@ -9911,7 +9896,7 @@
 
         # process the experts separately
         if name.find("chunk_experts") != -1:
-            n_experts = self.hparams["num_experts"] // 2  # see add_experts_per_group
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2  # see add_experts_per_group
             assert bid is not None
 
             if self._chunk_experts is None:
@@ -9938,7 +9923,7 @@
             else:
                 return
         elif name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -10331,7 +10316,6 @@ class HunYuanMoEModel(TextModel):
         super().set_gguf_parameters()
 
         hparams = self.hparams
-        self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
 
         moe_intermediate_size = hparams["moe_intermediate_size"]
@@ -10374,7 +10358,7 @@
             return
 
         if name.find("mlp.experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -10416,16 +10400,9 @@ class LLaDAMoEModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-
         if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
 
-        # number of experts used per token (top-k)
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-
         self.gguf_writer.add_mask_token_id(156895)
         self.gguf_writer.add_causal_attention(False)
         self.gguf_writer.add_diffusion_shift_logits(False)
@@ -10436,7 +10413,7 @@
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None:
@@ -10773,7 +10750,6 @@ class LFM2MoeModel(TextModel):
         super().set_gguf_parameters()
 
-        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
         self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
         self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
         self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
 
@@ -10794,7 +10770,7 @@
 
         # merge expert weights
         if 'experts' in name:
-            n_experts = self.hparams["num_experts"]
+            n_experts = self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             expert_cache = self._experts_cache.setdefault(bid, {})
@@ -10904,9 +10880,9 @@ class SmallThinkerModel(TextModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+        if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
-        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+        if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
         if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
             self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
@@ -10931,7 +10907,7 @@
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # process the experts separately
         if name.find("experts") != -1:
-            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"])
             assert bid is not None
 
             if self._experts is None: