llama : add support for Nemotron 3 Super (#20411)

* llama : add support for Nemotron 3 Super

This commit adds support for the Nemotron 3 Super model (120B.A12B)
enabling this model to be converted to GGUF format and run in llama.cpp.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com>
This commit is contained in:
Daniel Bevenius 2026-03-11 19:27:53 +01:00 committed by GitHub
parent 76ea1c1c46
commit eaf1d7930c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 97 additions and 14 deletions

View File

@ -9743,20 +9743,35 @@ class NemotronHModel(GraniteHybridModel):
# M: Mamba2, *: Attention, -: MLP
# MoE:
# M: Mamba2, *: Attention, E: Expert
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
if pattern is None:
self._ssm_layers = []
self._mlp_layers = []
elif isinstance(pattern, str):
self._ssm_layers = [i for i, val in enumerate(pattern) if val == "M"]
self._mlp_layers = [i for i, val in enumerate(pattern) if val == ("E" if self.is_moe else "-")]
else:
self._ssm_layers = [i for i, val in enumerate(pattern) if val == "mamba"]
self._mlp_layers = [i for i, val in enumerate(pattern) if val == "moe"]
def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
pattern = self.hparams.get("hybrid_override_pattern") or self.hparams.get("layers_block_type")
if pattern is None:
return []
assert len(pattern) == self.block_count, f"Mismatch between pattern ({len(pattern)}) and block_count ({self.block_count})!"
if isinstance(pattern, str):
return [i for i, val in enumerate(pattern) if val == "*"]
return [i for i, val in enumerate(pattern) if val == "attention"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_key_length(self.head_dim)
self.gguf_writer.add_value_length(self.head_dim)
head_dim = self.head_dim
if head_dim is None:
raise ValueError("Could not find the attention head dim in config")
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferable to
@ -9784,6 +9799,9 @@ class NemotronHModel(GraniteHybridModel):
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
if (latent_size := self.hparams.get("moe_latent_size")) is not None:
self.gguf_writer.add_moe_latent_size(latent_size)
def set_vocab(self):
super().set_vocab()
@ -9803,6 +9821,13 @@ class NemotronHModel(GraniteHybridModel):
name = name[len("language_model."):]
if self.is_moe and bid is not None:
# Skip Multi-Token Prediction (MTP) tensors. These are used for
# for speculative decoding but we don't include them in this model
# conversion. See https://github.com/ggml-org/llama.cpp/pull/18886
if "mtp" in name:
logger.info(f"gguf: Skipping MTP (Speculative) layer: {name}")
return []
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)

View File

@ -9081,6 +9081,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(

View File

@ -125,6 +125,7 @@ class Keys:
EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
MOE_LATENT_SIZE = "{arch}.moe_latent_size"
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
NUM_DEEPSTACK_LAYERS = "{arch}.n_deepstack_layers"
POOLING_TYPE = "{arch}.pooling_type"
@ -543,6 +544,8 @@ class MODEL_TENSOR(IntEnum):
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
MOE_LATENT_DOWN = auto() # nemotron 3 super
MOE_LATENT_UP = auto() # nemotron 3 super
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
@ -986,6 +989,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
@ -2913,6 +2918,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
# expert latent
MODEL_TENSOR.MOE_LATENT_DOWN,
MODEL_TENSOR.MOE_LATENT_UP,
# shared expert
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,

View File

@ -859,6 +859,9 @@ class GGUFWriter:
def add_moe_every_n_layers(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
def add_moe_latent_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_LATENT_SIZE.format(arch=self.arch), value)
def add_nextn_predict_layers(self, count: int) -> None:
self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count)

View File

@ -571,6 +571,14 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.gate_up_proj",
),
MODEL_TENSOR.MOE_LATENT_DOWN: (
"backbone.layers.{bid}.mixer.fc1_latent_proj", # nemotron 3 super
),
MODEL_TENSOR.MOE_LATENT_UP: (
"backbone.layers.{bid}.mixer.fc2_latent_proj", # nemotron 3 super
),
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox

View File

@ -185,6 +185,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" },
{ LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
@ -365,6 +366,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" },
{ LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" },
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
@ -1879,6 +1882,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_FFN_LATENT_DOWN,
LLM_TENSOR_FFN_LATENT_UP,
// MoE shared expert layer
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
@ -2754,6 +2759,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
// Nemotron 3 Super
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

View File

@ -189,6 +189,7 @@ enum llm_kv {
LLM_KV_EXPERT_GROUP_SCALE,
LLM_KV_EXPERTS_PER_GROUP,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_MOE_LATENT_SIZE,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_NUM_DEEPSTACK_LAYERS,
LLM_KV_POOLING_TYPE,
@ -385,6 +386,8 @@ enum llm_tensor {
LLM_TENSOR_FFN_GATE_CHEXPS,
LLM_TENSOR_FFN_UP_CHEXPS,
LLM_TENSOR_FFN_EXP_PROBS_B,
LLM_TENSOR_FFN_LATENT_DOWN,
LLM_TENSOR_FFN_LATENT_UP,
LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM,
LLM_TENSOR_LAYER_OUT_NORM,

View File

@ -89,6 +89,7 @@ struct llama_hparams {
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
uint32_t moe_latent_size = 0;
uint32_t nextn_predict_layers = 0;
float f_norm_eps;

View File

@ -135,6 +135,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_102B_A12B: return "102B.A12B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_120B_A12B: return "120B.A12B";
case LLM_TYPE_122B_A10B: return "122B.A10B";
case LLM_TYPE_196B_A11B: return "196B.A11B";
case LLM_TYPE_230B_A10B: return "230B.A10B";
@ -1861,10 +1862,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false);
switch (hparams.n_layer) {
case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B
case 56: type = LLM_TYPE_9B; break;
case 88: type = LLM_TYPE_120B_A12B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
@ -5544,6 +5547,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t n_ssm_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd;
// embeddings
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -5603,8 +5607,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0);
// MoE branch
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0);
// Shared expert branch
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);

View File

@ -126,6 +126,7 @@ enum llm_type {
LLM_TYPE_100B_A6B,
LLM_TYPE_102B_A12B, // Solar-Open
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_120B_A12B, // Nemotron 3 Super
LLM_TYPE_122B_A10B, // Qwen3.5
LLM_TYPE_196B_A11B, // Step3.5-Flash
LLM_TYPE_230B_A10B, // Minimax M2
@ -294,6 +295,10 @@ struct llama_layer {
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
// ff MoE latent proj
struct ggml_tensor * ffn_latent_down = nullptr;
struct ggml_tensor * ffn_latent_up = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
struct ggml_tensor * ffn_gate_shexp = nullptr;

View File

@ -114,9 +114,18 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * ffn_inp = cur;
ggml_tensor * inp_emb = cur;
ggml_tensor * inp_latent = cur;
if (model.layers[il].ffn_latent_down) {
inp_latent = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_down, cur);
}
ggml_tensor * router_logits = build_lora_mm(model.layers[il].ffn_gate_inp, cur);
cb(router_logits, "ffn_moe_logits", il);
ggml_tensor * moe_out =
build_moe_ffn(ffn_inp,
build_moe_ffn(inp_latent,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
nullptr, // no gate
@ -126,10 +135,15 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
LLM_FFN_RELU_SQR, hparams.expert_weights_norm,
hparams.expert_weights_scale,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
il);
il,
router_logits);
cb(moe_out, "ffn_moe_out", il);
ggml_tensor * ffn_shexp = build_ffn(ffn_inp,
if (model.layers[il].ffn_latent_up) {
moe_out = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_up, moe_out);
}
ggml_tensor * ffn_shexp = build_ffn(inp_emb,
model.layers[il].ffn_up_shexp, NULL, NULL,
NULL /* no gate */ , NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,