model : add BailingMoeV2 support (#16063)

* add BailingMoeV2 support

* update llm types

* undo

* undo

* update llm types

* add model collection link

* update

* almost working

* correct group selection and rename n_group_exp

* avoid large top_k and use argmax instead for now

if we had something like argmax2 that would be equivalent, but this works fine until then

* poke

* skip group selection when there are no tokens

* fix 1T conversion

* hopefully fixed expert group selection

third time's the charm?

* make expert group selection generally available

The new LLaDA2Moe model uses this method too, make it generally available regardless of architecture.

* allow n_expert_groups to be 1 (Kimi K2)

* address review suggestions
This commit is contained in:
Sigbjørn Skjæret 2025-10-20 21:38:20 +02:00 committed by GitHub
parent c9c1972e2c
commit 84bf3c6778
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 521 additions and 10 deletions

View File

@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7) - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
#### Multimodal #### Multimodal

View File

@ -892,8 +892,8 @@ class TextModel(ModelBase):
# ref: https://huggingface.co/JetBrains/Mellum-4b-base # ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum" res = "mellum"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
res = "llada-moe" res = "bailingmoe2"
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e": if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M # ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling" res = "granite-docling"
@ -8055,6 +8055,103 @@ class BailingMoeModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}") raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("BailingMoeV2ForCausalLM")
class BailingMoeV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
rope_scaling = self.hparams.get("rope_scaling") or {}
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_count(hparams["num_experts"])
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
self.gguf_writer.add_expert_group_count(hparams["n_group"])
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
if hparams["score_function"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["score_function"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "mlp.experts" in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
tensors: list[tuple[str, Tensor]] = []
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
return [(self.map_tensor_name(name), data_torch)]
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM") @ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
class GroveMoeModel(TextModel): class GroveMoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROVEMOE model_arch = gguf.MODEL_ARCH.GROVEMOE

View File

@ -139,7 +139,7 @@ models = [
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", }, {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
] ]

View File

@ -102,6 +102,8 @@ class Keys:
EXPERT_COUNT = "{arch}.expert_count" EXPERT_COUNT = "{arch}.expert_count"
EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_USED_COUNT = "{arch}.expert_used_count"
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func" EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
WAVTOKENIZER_DEC = auto() WAVTOKENIZER_DEC = auto()
PLM = auto() PLM = auto()
BAILINGMOE = auto() BAILINGMOE = auto()
BAILINGMOE2 = auto()
DOTS1 = auto() DOTS1 = auto()
ARCEE = auto() ARCEE = auto()
ERNIE4_5 = auto() ERNIE4_5 = auto()
@ -744,6 +747,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
MODEL_ARCH.PLM: "plm", MODEL_ARCH.PLM: "plm",
MODEL_ARCH.BAILINGMOE: "bailingmoe", MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee", MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5: "ernie4_5",
@ -2533,6 +2537,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP,
], ],
MODEL_ARCH.BAILINGMOE2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.DOTS1: [ MODEL_ARCH.DOTS1: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View File

@ -755,6 +755,12 @@ class GGUFWriter:
def add_expert_shared_count(self, count: int) -> None: def add_expert_shared_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
def add_expert_group_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
def add_expert_group_used_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
def add_expert_weights_scale(self, value: float) -> None: def add_expert_weights_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)

View File

@ -174,6 +174,7 @@ class TensorNameMap:
"h.{bid}.self_attention.query_key_value", # bloom "h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"model.layers.{bid}.self_attn.query_key_value", # persimmon "model.layers.{bid}.self_attn.query_key_value", # persimmon
"model.layers.{bid}.attention.query_key_value", # bailingmoe2
"h.{bid}.attn.c_attn", # gpt2 "h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2 "transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert "encoder.layers.{bid}.attn.Wqkv", # nomic-bert
@ -260,6 +261,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.out_proj", # gpt-j "transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon "model.layers.{bid}.self_attn.dense", # persimmon
"model.layers.{bid}.attention.dense", # bailingmoe2
"h.{bid}.attn.c_proj", # gpt2 "h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2 "transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
@ -373,6 +375,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_EXP_PROBS_B: ( MODEL_TENSOR.FFN_EXP_PROBS_B: (
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
), ),
@ -549,6 +552,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan "model.layers.{bid}.self_attn.query_layernorm", # hunyuan
"model.layers.{bid}.attention.query_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.q_norm", # embeddinggemma "layers.{bid}.self_attn.q_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.q_ln", # sea-lion "transformer.blocks.{bid}.attn.q_ln", # sea-lion
@ -563,6 +567,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.self_attention.k_layernorm", "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan "model.layers.{bid}.self_attn.key_layernorm", # hunyuan
"model.layers.{bid}.attention.key_layernorm", # bailingmoe2
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
"layers.{bid}.self_attn.k_norm", # embeddinggemma "layers.{bid}.self_attn.k_norm", # embeddinggemma
"transformer.blocks.{bid}.attn.k_ln", # sea-lion "transformer.blocks.{bid}.attn.k_ln", # sea-lion
@ -584,6 +589,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok "transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code "encoder.layer.{bid}.layer_norm_2", # jina-v2-code
"model.layers.{bid}.final_layernorm", # bailingmoe2
), ),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (

View File

@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
{ LLM_ARCH_PLM, "plm" }, { LLM_ARCH_PLM, "plm" },
{ LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" },
@ -135,6 +136,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" },
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
{ LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" },
{ LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" },
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
@ -1946,6 +1949,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
}, },
}, },
{
LLM_ARCH_BAILINGMOE2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
{ LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" },
{ LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
},
},
{ {
LLM_ARCH_DOTS1, LLM_ARCH_DOTS1,
{ {

View File

@ -89,6 +89,7 @@ enum llm_arch {
LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_WAVTOKENIZER_DEC,
LLM_ARCH_PLM, LLM_ARCH_PLM,
LLM_ARCH_BAILINGMOE, LLM_ARCH_BAILINGMOE,
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1, LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE, LLM_ARCH_ARCEE,
LLM_ARCH_ERNIE4_5, LLM_ARCH_ERNIE4_5,
@ -139,6 +140,8 @@ enum llm_kv {
LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_COUNT,
LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_USED_COUNT,
LLM_KV_EXPERT_SHARED_COUNT, LLM_KV_EXPERT_SHARED_COUNT,
LLM_KV_EXPERT_GROUP_COUNT,
LLM_KV_EXPERT_GROUP_USED_COUNT,
LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC, LLM_KV_EXPERT_GATING_FUNC,

View File

@ -63,6 +63,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX },
{ "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "bailing", LLM_CHAT_TEMPLATE_BAILING },
{ "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK },
{ "bailing2", LLM_CHAT_TEMPLATE_BAILING2 },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
@ -191,6 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_YANDEX; return LLM_CHAT_TEMPLATE_YANDEX;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) { } else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
return LLM_CHAT_TEMPLATE_BAILING; return LLM_CHAT_TEMPLATE_BAILING;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("\"HUMAN\"") && tmpl_contains("<think>")) {
return LLM_CHAT_TEMPLATE_BAILING_THINK;
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("<role>HUMAN</role>") && tmpl_contains("<|role_end|>")) {
return LLM_CHAT_TEMPLATE_BAILING2;
} else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
return LLM_CHAT_TEMPLATE_LLAMA4; return LLM_CHAT_TEMPLATE_LLAMA4;
} else if (tmpl_contains("<|endofuserprompt|>")) { } else if (tmpl_contains("<|endofuserprompt|>")) {
@ -644,8 +650,8 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << " Ассистент:[SEP]"; ss << " Ассистент:[SEP]";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) { } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
// Bailing (Ling) template // Bailing (Ling/Ring) template
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
@ -658,6 +664,33 @@ int32_t llm_chat_apply_template(
ss << "<role>" << role << "</role>" << message->content; ss << "<role>" << role << "</role>" << message->content;
} }
if (add_ass) {
ss << "<role>ASSISTANT</role>";
if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) {
ss << "<think>";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) {
// Bailing2 (Ling 2.0) template
bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
if (!has_system) {
ss << "<role>SYSTEM</role>detailed thinking off<|role_end|>";
}
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
role = "HUMAN";
} else {
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
}
ss << "<role>" << role << "</role>" << message->content << "<|role_end|>";
}
if (add_ass) { if (add_ass) {
ss << "<role>ASSISTANT</role>"; ss << "<role>ASSISTANT</role>";
} }

View File

@ -42,6 +42,8 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_MEGREZ,
LLM_CHAT_TEMPLATE_YANDEX, LLM_CHAT_TEMPLATE_YANDEX,
LLM_CHAT_TEMPLATE_BAILING, LLM_CHAT_TEMPLATE_BAILING,
LLM_CHAT_TEMPLATE_BAILING_THINK,
LLM_CHAT_TEMPLATE_BAILING2,
LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_LLAMA4,
LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_DOTS1,

View File

@ -950,6 +950,31 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(selection_probs, "ffn_moe_probs_biased", il); cb(selection_probs, "ffn_moe_probs_biased", il);
} }
// select top n_group_used expert groups
// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
if (hparams.n_expert_groups > 1 && n_tokens > 0) {
const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
// organize experts into n_expert_groups
ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups, n_tokens]
group_scores = ggml_get_rows(ctx0, ggml_reshape_4d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1], selection_groups->ne[2]), group_scores); // [1, 2, n_expert_groups, n_tokens]
// get top n_group_used expert groups
group_scores = ggml_sum_rows(ctx0, ggml_reshape_3d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2], group_scores->ne[3])); // [1, n_expert_groups, n_tokens]
group_scores = ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]); // [n_expert_groups, n_tokens]
ggml_tensor * expert_groups = ggml_top_k(ctx0, group_scores, hparams.n_group_used); // [n_group_used, n_tokens]
cb(expert_groups, "ffn_moe_group_topk", il);
// mask out the other groups
selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
cb(selection_probs, "ffn_moe_probs_masked", il);
}
// select experts // select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens] ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts->src[0], "ffn_moe_argsort", il);
@ -981,6 +1006,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il); cb(weights_sum, "ffn_moe_weights_sum", il);
if (arch == LLM_ARCH_BAILINGMOE2) {
weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
}
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il); cb(weights, "ffn_moe_weights_norm", il);

View File

@ -72,6 +72,8 @@ struct llama_hparams {
uint32_t n_ff_chexp = 0; uint32_t n_ff_chexp = 0;
uint32_t n_expert_shared = 0; uint32_t n_expert_shared = 0;
uint32_t n_norm_groups = 0; uint32_t n_norm_groups = 0;
uint32_t n_expert_groups = 0;
uint32_t n_group_used = 0;
uint32_t n_group_experts = 0; uint32_t n_group_experts = 0;
float expert_group_scale = 0.05f; float expert_group_scale = 0.05f;

View File

@ -116,8 +116,10 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_A13B: return "A13B";
case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_7B_A1B: return "7B.A1B";
case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_16B_A1B: return "16B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_300B_A47B: return "300B.A47B";
@ -481,11 +483,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
return; return;
} }
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false);
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@ -501,8 +505,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
if (hparams.n_expert > 0) { if (hparams.n_expert > 0) {
GGML_ASSERT(hparams.n_expert_used > 0); GGML_ASSERT(hparams.n_expert_used > 0);
GGML_ASSERT(hparams.n_expert_groups < hparams.n_expert);
if (hparams.n_expert_groups > 1) {
GGML_ASSERT(hparams.n_expert % hparams.n_expert_groups == 0);
GGML_ASSERT(hparams.n_group_used > 0);
GGML_ASSERT(hparams.n_group_used < hparams.n_expert_groups);
}
} else { } else {
GGML_ASSERT(hparams.n_expert_used == 0); GGML_ASSERT(hparams.n_expert_used == 0);
GGML_ASSERT(hparams.n_expert_groups == 0);
} }
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
@ -1888,6 +1899,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_BAILINGMOE2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func);
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 20: type = LLM_TYPE_16B_A1B; break;
case 21: type = LLM_TYPE_16B_A1B; break;
case 32: type = LLM_TYPE_100B_A6B; break;
case 33: type = LLM_TYPE_100B_A6B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_DOTS1: case LLM_ARCH_DOTS1:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -5498,6 +5532,70 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
} }
} break; } break;
case LLM_ARCH_BAILINGMOE2:
{
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_shared = hparams.n_expert_shared;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2");
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2");
for (int i = 0; i < n_layer; ++i) {
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= TENSOR_SKIP;
}
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags);
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers
const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared;
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags);
} else { // Dense layers
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
}
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags);
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags);
}
}
} break;
case LLM_ARCH_DOTS1: case LLM_ARCH_DOTS1:
{ {
const int64_t n_ff_exp = hparams.n_ff_exp; const int64_t n_ff_exp = hparams.n_ff_exp;
@ -6353,6 +6451,19 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
} }
if (arch == LLM_ARCH_BAILINGMOE2) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers);
}
if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
@ -17042,6 +17153,150 @@ struct llm_build_bailingmoe : public llm_graph_context {
} }
}; };
struct llm_build_bailingmoe2 : public llm_graph_context {
llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self_attention
{
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd));
ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd));
ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa));
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA);
cb(sa_out, "sa_out", il);
// MoE branch
cur = build_norm(sa_out,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
ggml_tensor * moe_out =
build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(moe_out, "ffn_moe_out", il);
{
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
}
}
cur = ggml_add(ctx0, cur, sa_out);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_dots1 : public llm_graph_context { struct llm_build_dots1 : public llm_graph_context {
llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
@ -19838,6 +20093,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{ {
llm = std::make_unique<llm_build_bailingmoe>(*this, params); llm = std::make_unique<llm_build_bailingmoe>(*this, params);
} break; } break;
case LLM_ARCH_BAILINGMOE2:
{
llm = std::make_unique<llm_build_bailingmoe2>(*this, params);
} break;
case LLM_ARCH_SEED_OSS: case LLM_ARCH_SEED_OSS:
{ {
llm = std::make_unique<llm_build_seed_oss>(*this, params); llm = std::make_unique<llm_build_seed_oss>(*this, params);
@ -20104,6 +20363,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4: case LLM_ARCH_EXAONE4:
case LLM_ARCH_MINICPM3: case LLM_ARCH_MINICPM3:
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1: case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_OPENAI_MOE:

View File

@ -109,8 +109,10 @@ enum llm_type {
LLM_TYPE_A13B, LLM_TYPE_A13B,
LLM_TYPE_7B_A1B, LLM_TYPE_7B_A1B,
LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B, LLM_TYPE_30B_A3B,
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_235B_A22B, LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_300B_A47B, // Ernie MoE big

View File

@ -1968,6 +1968,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
clean_spaces = false; clean_spaces = false;
} else if ( } else if (
tokenizer_pre == "bailingmoe" || tokenizer_pre == "bailingmoe" ||
tokenizer_pre == "bailingmoe2" ||
tokenizer_pre == "llada-moe") { tokenizer_pre == "llada-moe") {
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
clean_spaces = false; clean_spaces = false;