model : add AfmoeForCausalLM support (#16477)
* Add AFMOE model support

* Update to vocab

* Add model sizing

* Undo Rope change for ARCEE model

* Address review comments

* Update modeling code is_sliding -> use_rope, replace hard-coded logic

* Fix AFMOE tokenizer

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update AFMoE tokenizer class identification to be more unique

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
6cd0cf72ce
commit
e1fcf8b09b

convert_hf_to_gguf.py
@@ -1124,6 +1124,9 @@ class TextModel(ModelBase):
        if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
            # ref: https://huggingface.co/JetBrains/Mellum-4b-base
            res = "mellum"
        if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
            # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
            res = "afmoe"
        if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
            # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
            res = "bailingmoe2"

@@ -2533,6 +2536,81 @@ class ArceeModel(LlamaModel):
            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])


@ModelBase.register("AfmoeForCausalLM")
class AfmoeModel(LlamaModel):
    model_arch = gguf.MODEL_ARCH.AFMOE

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        # MoE parameters
        if (n_experts := self.hparams.get("num_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
            self.gguf_writer.add_expert_shared_count(n_shared_experts)
        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
        if (n_dense_layers := self.hparams.get("num_dense_layers")) is not None:
            self.gguf_writer.add_leading_dense_block_count(n_dense_layers)

        # Expert Gating Function
        score_func = self.hparams.get("score_func")
        if score_func == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif score_func == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        elif score_func is not None:
            raise ValueError(f"Unsupported score_func value: {score_func}")

        # Route normalization and scaling
        if (route_norm := self.hparams.get("route_norm")) is not None:
            self.gguf_writer.add_expert_weights_norm(route_norm)
        if (route_scale := self.hparams.get("route_scale")) is not None:
            self.gguf_writer.add_expert_weights_scale(route_scale)

        # Sliding window attention
        if (sliding_window := self.hparams.get("sliding_window")) is not None:
            self.gguf_writer.add_sliding_window(sliding_window)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # expert weights are stored per-expert in the HF checkpoint;
        # process the experts separately and merge them into single 3D tensors below
        if name.find("mlp.experts") != -1:
            n_experts = self.hparams["num_experts"]
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                tensors: list[tuple[str, Tensor]] = []

                # merge the experts into a single 3d tensor
                for w_name in ["gate_proj", "up_proj", "down_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
                        datas.append(self._experts[bid][ename_to_retrieve])
                        del self._experts[bid][ename_to_retrieve]

                    data_torch = torch.stack(datas, dim=0)
                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
                    new_name = self.map_tensor_name(merged_name)
                    tensors.append((new_name, data_torch))

                return tensors
            else:
                return []

        if name.endswith(".expert_bias"):
            name = name.replace(".expert_bias", ".expert_bias.bias")

        return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register(
    "LlavaForConditionalGeneration", # pixtral
    "Mistral3ForConditionalGeneration", # mistral small 3.1
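
A note on the merge in modify_tensors above: the HF checkpoint stores each expert's gate/up/down projection as its own 2D matrix, and the converter collects all of them per layer and stacks them into one 3D tensor per projection, so the GGUF file carries a single ffn_*_exps tensor per layer. A minimal sketch of that stacking, with toy dimensions (4 experts, n_embd=8, moe_intermediate_size=16) chosen purely for illustration:

import torch

n_experts, n_embd, n_ff_exp = 4, 8, 16

# per-expert matrices, as they appear in the HF checkpoint
per_expert = [torch.randn(n_ff_exp, n_embd) for _ in range(n_experts)]

# stacked along a new leading expert dimension, as written to GGUF
merged = torch.stack(per_expert, dim=0)
assert merged.shape == (n_experts, n_ff_exp, n_embd)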

convert_hf_to_gguf_update.py
@@ -139,6 +139,7 @@ models = [
    {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
    {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
    {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
    {"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
    {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
    {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
    {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
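
The chkhsh values matched in convert_hf_to_gguf.py come from this list: the update script downloads each tokenizer, encodes a fixed check string, and hashes the token IDs, which is how a new pre-tokenizer such as "afmoe" gets its fingerprint. A hedged sketch of the idea (the real scripts use one long fixed check string, abbreviated to a placeholder here):

from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "..."  # placeholder; the real scripts share a fixed multilingual check string

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Tokenizer")
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
# get_vocab_base_pre() compares digests like this one against its table,
# e.g. "49fc0303..." maps to res = "afmoe"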

gguf-py/gguf/constants.py
@@ -409,6 +409,7 @@ class MODEL_ARCH(IntEnum):
    BAILINGMOE2 = auto()
    DOTS1 = auto()
    ARCEE = auto()
    AFMOE = auto()
    ERNIE4_5 = auto()
    ERNIE4_5_MOE = auto()
    HUNYUAN_MOE = auto()

@@ -464,6 +465,7 @@ class MODEL_TENSOR(IntEnum):
    ATTN_POST_NORM = auto()
    ATTN_ROT_EMBD = auto()
    ATTN_SINKS = auto()
    ATTN_GATE = auto()
    FFN_GATE_INP = auto()
    FFN_GATE_INP_SHEXP = auto()
    FFN_NORM = auto()

@@ -776,6 +778,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
    MODEL_ARCH.DOTS1: "dots1",
    MODEL_ARCH.ARCEE: "arcee",
    MODEL_ARCH.AFMOE: "afmoe",
    MODEL_ARCH.ERNIE4_5: "ernie4_5",
    MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe",
    MODEL_ARCH.FALCON_H1: "falcon-h1",

@@ -828,6 +831,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
    MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
    MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
    MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
    MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
    MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
    MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",

@@ -2693,6 +2697,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.AFMOE: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_GATE,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_PRE_NORM,
        MODEL_TENSOR.FFN_POST_NORM,
        MODEL_TENSOR.FFN_EXP_PROBS_B,
    ],
    MODEL_ARCH.ERNIE4_5: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/tensor_mapping.py
@@ -314,6 +314,10 @@ class TensorNameMap:
            "model.layers.{bid}.self_attn.sinks", # openai-moe
        ),

        MODEL_TENSOR.ATTN_GATE: (
            "model.layers.{bid}.self_attn.gate_proj", # afmoe
        ),

        # Feed-forward norm
        MODEL_TENSOR.FFN_NORM: (
            "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox

@@ -340,11 +344,12 @@ class TensorNameMap:
            "model.layers.{bid}.feedforward_layernorm", # apertus
        ),

        # Post feed-forward norm
        # Pre feed-forward norm
        MODEL_TENSOR.FFN_PRE_NORM: (
            "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
            "layers.{bid}.pre_feedforward_layernorm", # embeddinggemma
            "model.layers.{bid}.pre_ff_layernorm.weight",
            "model.layers.{bid}.pre_mlp_layernorm", # afmoe
        ),

        # Post feed-forward norm

@@ -370,6 +375,7 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.gate.wg", # hunyuan
            "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
            "model.layers.{bid}.feed_forward.gate", # lfm2moe
            "model.layers.{bid}.mlp.router.gate", # afmoe
        ),

        MODEL_TENSOR.FFN_GATE_INP_SHEXP: (

@@ -380,6 +386,7 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
            "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
            "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
            "model.layers.{bid}.mlp.expert_bias", # afmoe
            "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
            "model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
        ),
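
These entries are what map_tensor_name() consults when converting AFMoE checkpoints: each HF name pattern resolves to a canonical GGUF tensor name from TENSOR_NAMES. A simplified, hypothetical version of that lookup (the real TensorNameMap also handles suffixes and many architectures; patterns and targets below are taken from the mappings above):

# hypothetical, simplified lookup for illustration only
AFMOE_MAP = {
    "model.layers.{bid}.self_attn.gate_proj": "blk.{bid}.attn_gate",
    "model.layers.{bid}.mlp.router.gate": "blk.{bid}.ffn_gate_inp",
    "model.layers.{bid}.mlp.expert_bias": "blk.{bid}.exp_probs_b",
}

def map_name(hf_name: str, bid: int) -> str:
    for pattern, target in AFMOE_MAP.items():
        if hf_name == pattern.format(bid=bid):
            return target.format(bid=bid)
    raise KeyError(hf_name)

print(map_name("model.layers.3.mlp.router.gate", bid=3))  # blk.3.ffn_gate_inp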

src/CMakeLists.txt
@@ -35,6 +35,7 @@ add_library(llama
            unicode-data.cpp
            unicode.cpp
            unicode.h
            models/afmoe.cpp
            models/apertus.cpp
            models/arcee.cpp
            models/arctic.cpp

src/llama-arch.cpp
@@ -90,6 +90,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
    { LLM_ARCH_DOTS1, "dots1" },
    { LLM_ARCH_ARCEE, "arcee" },
    { LLM_ARCH_AFMOE, "afmoe" },
    { LLM_ARCH_ERNIE4_5, "ernie4_5" },
    { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
    { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },

@@ -333,6 +334,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
        { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_AFMOE,
        {
            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
            { LLM_TENSOR_OUTPUT, "output" },
            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
            { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
        },
    },
    {
        LLM_ARCH_LLAMA4,
        {

@@ -2444,6 +2475,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h
@@ -94,6 +94,7 @@ enum llm_arch {
    LLM_ARCH_BAILINGMOE2,
    LLM_ARCH_DOTS1,
    LLM_ARCH_ARCEE,
    LLM_ARCH_AFMOE,
    LLM_ARCH_ERNIE4_5,
    LLM_ARCH_ERNIE4_5_MOE,
    LLM_ARCH_HUNYUAN_MOE,

@@ -312,6 +313,7 @@ enum llm_tensor {
    LLM_TENSOR_ATTN_POST_NORM,
    LLM_TENSOR_ATTN_ROT_EMBD,
    LLM_TENSOR_ATTN_SINKS,
    LLM_TENSOR_ATTN_GATE,
    LLM_TENSOR_FFN_GATE_INP,
    LLM_TENSOR_FFN_GATE_INP_SHEXP,
    LLM_TENSOR_FFN_NORM,

src/llama-model.cpp
@@ -84,6 +84,7 @@ const char * llm_type_name(llm_type type) {
        case LLM_TYPE_15B: return "15B";
        case LLM_TYPE_16B: return "16B";
        case LLM_TYPE_20B: return "20B";
        case LLM_TYPE_26B: return "26B";
        case LLM_TYPE_27B: return "27B";
        case LLM_TYPE_30B: return "30B";
        case LLM_TYPE_32B: return "32B";

@@ -695,6 +696,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_AFMOE:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false);
                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);

                // Set up interleaved sliding window attention (ISWA)
                // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4)
                if (hparams.n_swa > 0) {
                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                    hparams.set_swa_pattern(4);
                } else {
                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
                }

                // Default to sigmoid if not set
                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
                    hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
                }

                switch (hparams.n_layer) {
                    case 56: type = LLM_TYPE_6B; break;
                    case 32: type = LLM_TYPE_26B; break;
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_DECI:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
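
The set_swa_pattern(4) call encodes the "3 sliding - 1 full" comment: within each group of four consecutive layers, three use the sliding window and the fourth attends globally. A small sketch of the resulting layer pattern, assuming the convention that the last layer in each group of four is the global one:

def swa_layers(n_layer: int, pattern: int = 4) -> list[bool]:
    # True = sliding-window layer, False = global-attention layer
    return [(il % pattern) != pattern - 1 for il in range(n_layer)]

print(swa_layers(8))  # [True, True, True, False, True, True, True, False]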

@@ -5749,6 +5781,71 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                }
            } break;
        case LLM_ARCH_AFMOE:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                // output
                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);

                // if output is NULL, init from the input tok embed
                if (output == NULL) {
                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                }

                const int64_t n_ff_exp = hparams.n_ff_exp;
                const int64_t n_expert_shared = hparams.n_expert_shared;

                for (int i = 0; i < n_layer; ++i) {
                    auto & layer = layers[i];

                    // dual attention normalization
                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

                    // attention projections
                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

                    // Q/K normalization
                    layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
                    layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);

                    // attention gating
                    layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);

                    // dual ffn normalization
                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);

                    if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) {
                        // MoE layers
                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);

                        // grouped expert weights
                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);

                        // shared expert
                        if (n_expert_shared > 0) {
                            const int64_t n_ff_shexp = n_ff_exp * n_expert_shared;
                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
                        }
                    } else {
                        // Dense layers
                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                    }
                }
            } break;
        case LLM_ARCH_ERNIE4_5:
        case LLM_ARCH_ERNIE4_5_MOE:
            {
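
The grouped expert shapes above are the converter's stacked tensors seen through ggml's convention: ggml lists dimensions fastest-varying first, so a torch tensor stacked as (n_expert, n_ff_exp, n_embd) is declared here as {n_embd, n_ff_exp, n_expert}. A quick sanity check of that correspondence, with illustrative sizes:

import torch

n_expert, n_ff_exp, n_embd = 4, 16, 8

merged = torch.randn(n_expert, n_ff_exp, n_embd)  # converter output (gate/up projections)

# ggml declares the same memory layout with the dimension order reversed
ggml_ne = tuple(reversed(merged.shape))
assert ggml_ne == (n_embd, n_ff_exp, n_expert)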

@@ -7243,6 +7340,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_arcee>(*this, params);
            } break;
        case LLM_ARCH_AFMOE:
            {
                llm = std::make_unique<llm_build_afmoe>(*this, params);
            } break;
        case LLM_ARCH_ERNIE4_5:
            {
                llm = std::make_unique<llm_build_ernie4_5>(*this, params);

@@ -7528,6 +7629,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_MINIMAX_M2:
        case LLM_ARCH_COGVLM:
        case LLM_ARCH_PANGU_EMBED:
        case LLM_ARCH_AFMOE:
            return LLAMA_ROPE_TYPE_NEOX;

        case LLM_ARCH_QWEN2VL:

src/llama-model.h
@@ -76,6 +76,7 @@ enum llm_type {
    LLM_TYPE_15B,
    LLM_TYPE_16B,
    LLM_TYPE_20B,
    LLM_TYPE_26B,
    LLM_TYPE_27B,
    LLM_TYPE_30B,
    LLM_TYPE_32B,

@@ -234,6 +235,7 @@ struct llama_layer {
    struct ggml_tensor * wk_enc = nullptr;
    struct ggml_tensor * wv_enc = nullptr;
    struct ggml_tensor * wo_enc = nullptr;
    struct ggml_tensor * wqkv_gate = nullptr;

    // attention bias
    struct ggml_tensor * bq = nullptr;

src/llama-vocab.cpp
@@ -443,6 +443,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                };
                break;
            case LLAMA_VOCAB_PRE_TYPE_AFMOE:
                regex_exprs = {
                    // Digit handling - uses custom implementation in unicode.cpp
                    // Groups digits with leading 1-2 based on total length modulo 3
                    "\\p{AFMoE_digits}",
                    // CJK and Asian scripts (using direct Unicode literals)
                    "[一-鿿㐀-䶿豈--ゟ゠-ヿ・-゚⼀-เ--ក-က-႟ꩠ-ꩿꧠ-가-ᄀ-ᇿ]+",
                    // Main BPE pattern
                    "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                };
                break;
            default:
                // default regex for BPE tokenization pre-processing
                regex_exprs = {

@@ -1993,6 +2004,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "grok-2") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
                clean_spaces = false;
            } else if (
                tokenizer_pre == "afmoe") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_AFMOE;
                clean_spaces = false;
            } else if (
                tokenizer_pre == "minimax-m2") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;

src/llama-vocab.h
@@ -50,6 +50,7 @@ enum llama_vocab_pre_type {
    LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
    LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
    LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
    LLAMA_VOCAB_PRE_TYPE_AFMOE = 42,
};

struct LLM_KV;

src/models/afmoe.cpp
@@ -0,0 +1,187 @@
#include "models.h"

llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    const int64_t n_embd_head = hparams.n_embd_head_v;
    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);

    // MuP scaling: embeddings * sqrt(hidden_size)
    // mup_enabled = true, hidden_size = 1024, scale = 32.0
    inpL = ggml_scale(ctx0, inpL, sqrtf(float(n_embd)));
    cb(inpL, "inp_embd_scaled", -1);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_attn = build_attn_inp_kv_iswa();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    const float kq_scale = 1.0f/sqrtf(float(n_embd_head));

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // dual attention normalization (pre)
        cur = build_norm(inpL,
                model.layers[il].attn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self-attention
        {
            ggml_tensor * attn_inp = cur; // save input for gate computation

            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
            cb(Vcur, "Vcur", il);

            // compute gate from input
            ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp);
            cb(gate, "attn_gate_proj", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);

            // Q/K normalization
            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
            cb(Qcur, "Qcur_normed", il);
            cb(Kcur, "Kcur_normed", il);

            // RoPE only for sliding_attention layers
            const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
                                  ((il + 1) % hparams.n_no_rope_layer_step) != 0;
            if (use_rope) {
                Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Qcur, "Qcur_rope", il);

                Kcur = ggml_rope_ext(
                        ctx0, Kcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Kcur, "Kcur_rope", il);
            }

            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            cur = build_attn(inp_attn,
                    NULL, NULL, // wo will be applied after gating
                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
            cb(cur, "attn_out", il);

            // attention gating: attn_out * sigmoid(gate) BEFORE o_proj
            gate = ggml_sigmoid(ctx0, gate);
            cb(gate, "attn_gate_sig", il);
            cur = ggml_mul(ctx0, cur, gate);
            cb(cur, "attn_gated", il);

            // now apply output projection
            cur = build_lora_mm(model.layers[il].wo, cur);
            cb(cur, "attn_o_proj", il);
        }

        // dual attention normalization (post)
        cur = build_norm(cur,
                model.layers[il].attn_post_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        if (il == n_layer - 1 && inp_out_ids) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        // dual ffn normalization (pre)
        cur = build_norm(ffn_inp,
                model.layers[il].ffn_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        // MoE or dense FFN
        if ((uint32_t)il >= hparams.n_layer_dense_lead) {
            // MoE layer with sigmoid routing, normalization, and scaling
            ggml_tensor * moe_out = build_moe_ffn(cur,
                    model.layers[il].ffn_gate_inp,
                    model.layers[il].ffn_up_exps,
                    model.layers[il].ffn_gate_exps,
                    model.layers[il].ffn_down_exps,
                    model.layers[il].ffn_exp_probs_b,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU,
                    hparams.expert_weights_norm,  // norm_w (route_norm=True)
                    hparams.expert_weights_scale, // scale_w
                    hparams.expert_weights_scale, // w_scale (route_scale=2.826)
                    (llama_expert_gating_func_type) hparams.expert_gating_func,
                    il);
            cb(moe_out, "ffn_moe_out", il);

            // shared expert
            if (hparams.n_expert_shared > 0) {
                ggml_tensor * ffn_shexp = build_ffn(cur,
                        model.layers[il].ffn_up_shexp, NULL, NULL,
                        model.layers[il].ffn_gate_shexp, NULL, NULL,
                        model.layers[il].ffn_down_shexp, NULL, NULL,
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
                cb(cur, "ffn_out", il);
            } else {
                cur = moe_out;
            }
        } else {
            // dense layer
            cur = build_ffn(cur,
                    model.layers[il].ffn_up, NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        }

        // dual ffn normalization (post)
        cur = build_norm(cur,
                model.layers[il].ffn_post_norm, NULL,
                LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        cur = ggml_add(ctx0, cur, ffn_inp);
        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    cur = build_norm(cur,
            model.output_norm, NULL,
            LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
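
Two details of this graph are easy to check on paper. The MuP comment is self-consistent: sqrt(n_embd) = sqrt(1024) = 32, matching the stated scale of 32.0. And the attention gating multiplies the attention output elementwise by a sigmoid gate computed from the layer input, before the output projection is applied. A shape-level sketch of the gating in isolation (illustrative PyTorch with random weights, not the ggml code):

import torch

n_tokens, n_embd = 5, 1024

x = torch.randn(n_tokens, n_embd)          # normalized layer input
attn_out = torch.randn(n_tokens, n_embd)   # attention output before o_proj
w_gate = torch.randn(n_embd, n_embd)       # plays the role of wqkv_gate
w_o = torch.randn(n_embd, n_embd)          # plays the role of wo

gated = attn_out * torch.sigmoid(x @ w_gate)  # gate applied BEFORE o_proj
out = gated @ w_o
assert out.shape == (n_tokens, n_embd)

assert n_embd ** 0.5 == 32.0  # the MuP scale quoted in the comment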

src/models/models.h
@@ -57,6 +57,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
        int il) const;
};

struct llm_build_afmoe : public llm_graph_context {
    llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_apertus : public llm_graph_context {
    llm_build_apertus(const llama_model & model, const llm_graph_params & params);
};

src/unicode.cpp
@@ -729,6 +729,80 @@ static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
    return bpe_offsets;
}

// AFMOE digit handling: splits digits with leading 1-2 based on total length modulo 3
static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & text, const std::vector<size_t> & offsets) {
    std::vector<size_t> bpe_offsets;
    bpe_offsets.reserve(offsets.size());

    const auto cpts = unicode_cpts_from_utf8(text);

    size_t start = 0;
    for (auto offset : offsets) {
        const size_t offset_ini = start;
        const size_t offset_end = start + offset;
        assert(offset_end <= cpts.size());
        start = offset_end;

        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
        };

        size_t _prev_end = offset_ini;
        auto _add_token = [&] (const size_t end) -> size_t {
            assert(_prev_end <= end && end <= offset_end);
            size_t len = end - _prev_end;
            if (len > 0) {
                bpe_offsets.push_back(len);
            }
            _prev_end = end;
            return len;
        };

        for (size_t pos = offset_ini; pos < offset_end; ) {
            const auto flags = _get_flags(pos);

            // Handle digit sequences with special splitting logic
            if (flags.is_number) {
                size_t digit_start = pos;
                size_t digit_count = 0;

                // Count consecutive digits
                while (_get_flags(pos).is_number && pos < offset_end) {
                    digit_count++;
                    pos++;
                }

                // Split based on total length modulo 3
                size_t remainder = digit_count % 3;
                size_t current = digit_start;

                // Emit leading 1-2 digits if needed
                if (remainder > 0) {
                    _add_token(current + remainder);
                    current += remainder;
                }

                // Emit groups of 3
                while (current < digit_start + digit_count) {
                    _add_token(current + 3);
                    current += 3;
                }
                continue;
            }

            // For non-digits, just move forward
            pos++;
        }

        // Add any remaining content
        if (_prev_end < offset_end) {
            _add_token(offset_end);
        }
    }

    return bpe_offsets;
}

static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
    std::vector<size_t> bpe_offsets;

@@ -742,6 +816,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
    } else if (regex_expr == "\\p{Han}+") {
        // K2's first pattern - handle all K2 patterns together
        bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
    } else if (regex_expr == "\\p{AFMoE_digits}") {
        // AFMOE digit pattern - use custom implementation for proper splitting
        bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
    }

    return bpe_offsets;
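
The modulo-3 rule implemented above splits a digit run into an optional leading group of one or two digits followed by groups of exactly three, so "1234567" becomes "1", "234", "567". A pure-Python sketch of the same grouping over a digit string:

def split_digits(s: str) -> list[str]:
    # leading group of len(s) % 3 digits (if any), then groups of 3
    out = []
    rem = len(s) % 3
    if rem:
        out.append(s[:rem])
    out.extend(s[i:i + 3] for i in range(rem, len(s), 3))
    return out

assert split_digits("1234567") == ["1", "234", "567"]
assert split_digits("123456") == ["123", "456"]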