diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e1c78e3b18..01b15e4b9a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8490,8 +8490,18 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): class NemotronHModel(GraniteHybridModel): """Hybrid mamba2/attention model from NVIDIA""" model_arch = gguf.MODEL_ARCH.NEMOTRON_H + is_moe: bool = False def __init__(self, *args, **kwargs): + # We have to determine the correct model architecture (MoE vs non-MoE) before + # calling the parent __init__. This is because the parent constructor + # uses self.model_arch to build the tensor name map, and all MoE-specific + # mappings would be missed if it were called with the default non-MoE arch. + hparams = ModelBase.load_hparams(args[0], self.is_mistral_format) + if "num_experts_per_tok" in hparams: + self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE + self.is_moe = True + super().__init__(*args, **kwargs) # Save the top-level head_dim for later @@ -8503,9 +8513,11 @@ class NemotronHModel(GraniteHybridModel): # Update the ssm / attn / mlp layers # M: Mamba2, *: Attention, -: MLP + # MoE: + # M: Mamba2, *: Attention, E: Expert hybrid_override_pattern = self.hparams["hybrid_override_pattern"] self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"] - self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"] + self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")] def get_attn_layers(self): hybrid_override_pattern = self.hparams["hybrid_override_pattern"] @@ -8521,10 +8533,28 @@ class NemotronHModel(GraniteHybridModel): # Set feed_forward_length # NOTE: This will trigger an override warning. This is preferrable to # duplicating all the parent logic - n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) - self.gguf_writer.add_feed_forward_length([ - n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) - ]) + if not self.is_moe: + n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) + self.gguf_writer.add_feed_forward_length([ + n_ff if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + else: + moe_intermediate_size = self.hparams["moe_intermediate_size"] + self.gguf_writer.add_feed_forward_length([ + moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count) + ]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"]) + self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_group_count(self.hparams["n_group"]) + + # number of experts used per token (top-k) + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) def set_vocab(self): super().set_vocab() @@ -8532,7 +8562,81 @@ class NemotronHModel(GraniteHybridModel): # The tokenizer _does_ add a BOS token (via post_processor type # TemplateProcessing) but does not set add_bos_token to true in the # config, so we need to explicitly override it here. - self.gguf_writer.add_add_bos_token(True) + if not self.is_moe: + self.gguf_writer.add_add_bos_token(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if self.is_moe and bid is not None: + if name.endswith("mixer.gate.e_score_correction_bias"): + new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") + mapped_name = self.map_tensor_name(new_name) + return [(mapped_name, data_torch)] + + if name.endswith("mixer.dt_bias"): + new_name = name.replace("dt_bias", "dt.bias") + mapped_name = self.map_tensor_name(new_name) + return [(mapped_name, data_torch)] + + if name.endswith("mixer.conv1d.weight"): + squeezed_data = data_torch.squeeze() + mapped_name = self.map_tensor_name(name) + return [(mapped_name, squeezed_data)] + + if name.endswith("mixer.A_log"): + transformed_data = -torch.exp(data_torch) + reshaped_data = transformed_data.squeeze().reshape(-1, 1) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.endswith("mixer.D"): + reshaped_data = data_torch.squeeze().reshape(-1, 1) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.endswith("mixer.norm.weight"): + reshaped_data = data_torch.reshape(8, 512) + mapped_name = self.map_tensor_name(name) + return [(mapped_name, reshaped_data)] + + if name.find("mixer.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 2: + # merge the experts into a single tensor + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return super().modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") @ModelBase.register("BailingMoeForCausalLM") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8ef4a23a10..5ca4efd043 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -413,6 +413,7 @@ class MODEL_ARCH(IntEnum): JAIS = auto() NEMOTRON = auto() NEMOTRON_H = auto() + NEMOTRON_H_MOE = auto() EXAONE = auto() EXAONE4 = auto() GRANITE = auto() @@ -786,6 +787,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.JAIS: "jais", MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.NEMOTRON_H: "nemotron_h", + MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe", MODEL_ARCH.EXAONE: "exaone", MODEL_ARCH.EXAONE4: "exaone4", MODEL_ARCH.GRANITE: "granite", @@ -2529,6 +2531,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.NEMOTRON_H_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + # experts + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + # shared expert + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.EXAONE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b320e2b4b2..610227231f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -379,6 +379,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.gate", # lfm2moe "model.layers.{bid}.mlp.router.gate", # afmoe "layers.{bid}.gate", # mistral-large + "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -392,6 +393,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.expert_bias", # afmoe "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2 + "backbone.layers.{bid}.mixer.gate.e_score_correction" # nemotron-h-moe ), # Feed-forward up @@ -440,7 +442,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) ernie4.5-moe, nemotron-h-moe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe @@ -454,6 +456,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.down_proj", "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan "layers.{bid}.shared_experts.w3", # mistral-large + "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe ), MODEL_TENSOR.FFN_UP_CHEXP: ( @@ -548,7 +551,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx - "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) ernie4.5-moe nemotron-h-moe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 @@ -563,6 +566,7 @@ class TensorNameMap: "model.layers.{bid}.shared_mlp.output_linear", # granitemoe "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan "layers.{bid}.shared_experts.w2", # mistral-large + "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe ), MODEL_TENSOR.FFN_DOWN_CHEXP: ( @@ -706,6 +710,7 @@ class TensorNameMap: "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 "model.layers.{bid}.linear_attn.dt_proj", # qwen3next + "backbone.layers.{bid}.mixer.dt", # nemotron-h-moe ), MODEL_TENSOR.SSM_DT_NORM: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 64ad1b7769..05b12a6072 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, + { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, @@ -1763,6 +1764,39 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NEMOTRON_H_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + // mamba(2) ssm layers + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + // attention layers + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + // dense FFN + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // MoE FFN (for MoE layers) + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B,"blk.%d.exp_probs_b" }, + // MoE shared expert layer + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_EXAONE, { @@ -2817,6 +2851,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: return true; default: diff --git a/src/llama-arch.h b/src/llama-arch.h index e113180024..455658f5dc 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -79,6 +79,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, + LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8909bbfb95..b096b93a9a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1089,6 +1089,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_relu(ctx0, cur); cb(cur, "ffn_moe_relu", il); } break; + case LLM_FFN_RELU_SQR: + if (gate_exps) { + // TODO: add support for gated squared relu + GGML_ABORT("fatal error: gated squared relu not implemented"); + } else { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_moe_relu_sqr", il); + } break; default: GGML_ABORT("fatal error"); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 050735afc0..15547403b4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -120,6 +120,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -1797,6 +1798,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); @@ -1812,7 +1814,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -5159,6 +5168,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { // mamba2 Mixer SSM params // NOTE: int64_t for tensor dimensions @@ -5169,6 +5179,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5218,12 +5231,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } } } } break; @@ -6850,7 +6877,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_NEMOTRON_H) { + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); @@ -6905,7 +6933,8 @@ void llama_model::print_info() const { if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID) { + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -7086,7 +7115,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (arch == LLM_ARCH_FALCON_H1) { filter_attn = [&](int32_t) { return true; }; filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H) { + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { filter_attn = [&](int32_t il) { return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; @@ -7457,6 +7486,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: { llm = std::make_unique(*this, params); } break; @@ -7741,6 +7771,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/src/llama-model.h b/src/llama-model.h index f8342cf2cb..c6eb953188 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 5414348888..eb135e63f1 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -107,12 +107,41 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * } ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp = cur; + ggml_tensor * moe_out = + build_moe_ffn(ffn_inp, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, // no gate + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_RELU_SQR, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + model.layers[il].ffn_up_shexp, NULL, NULL, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } cur = build_cvec(cur, il); cb(cur, "l_out", il);