model : Minimax M2 (#16831)

* Model: Minimax M2

* Cleanup

* Cleanup pt. 2

* Cleanup pt. 3

* Update convert_hf_to_gguf_update.py - merge catch blocks

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Remove vocab models and test

* Remove all redundant hparam settings covered by TextModel

* Move super to start, don't set block_count

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update gguf-py/gguf/constants.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Piotr Wilkin (ilintar) 2025-10-31 21:20:47 +01:00 committed by GitHub
parent e58d585604
commit 0de0a01576
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 284 additions and 1 deletions

View File

@ -1054,6 +1054,9 @@ class TextModel(ModelBase):
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
res = "granite-docling"
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
res = "minimax-m2"
if res is None:
logger.warning("\n")
@ -7126,6 +7129,64 @@ class DeepseekV2Model(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MINIMAXM2
_experts_cache: dict[int, dict[str, Tensor]] = {}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["num_experts"] = self.hparams["num_local_experts"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hparams["scoring_func"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif self.hparams["scoring_func"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
self.gguf_writer.add_expert_feed_forward_length(self.find_hparam(["intermediate_size"]))
self.gguf_writer.add_rope_dimension_count(self.find_hparam(["rotary_dim"]))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
# merge expert weights
if 'experts' in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
expert_cache = self._experts_cache.setdefault(bid, {})
expert_cache[name] = data_torch
expert_weights = ["w1", "w2", "w3"]
# not enough expert weights to merge
if len(expert_cache) < n_experts * len(expert_weights):
return []
tensors: list[tuple[str, Tensor]] = []
for w_name in expert_weights:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
datas.append(expert_cache[ename])
del expert_cache[ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
del self._experts_cache[bid]
return tensors
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Dots1ForCausalLM")
class Dots1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.DOTS1

View File

@ -141,6 +141,7 @@ models = [
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
@ -435,7 +436,7 @@ for model in models:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
else:
tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
except OSError as e:
except (OSError, TypeError) as e:
logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
continue # Skip this model and continue with the next one in the loop

View File

@ -425,6 +425,7 @@ class MODEL_ARCH(IntEnum):
GROVEMOE = auto()
APERTUS = auto()
COGVLM = auto()
MINIMAXM2 = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@ -790,6 +791,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.SEED_OSS: "seed_oss",
MODEL_ARCH.GROVEMOE: "grovemoe",
MODEL_ARCH.APERTUS: "apertus",
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
}
@ -2921,6 +2923,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_CHEXP,
MODEL_TENSOR.FFN_UP_CHEXP,
],
MODEL_ARCH.MINIMAXM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.COGVLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -381,6 +381,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
"model.layers.{bid}.block_sparse_moe.e_score_correction", # minimax-m2
),
# Feed-forward up

View File

@ -105,6 +105,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_GROVEMOE, "grovemoe" },
{ LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -2355,6 +2356,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
},
},
{
LLM_ARCH_MINIMAX_M2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_COGVLM,
{

View File

@ -109,6 +109,7 @@ enum llm_arch {
LLM_ARCH_SEED_OSS,
LLM_ARCH_GROVEMOE,
LLM_ARCH_APERTUS,
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_UNKNOWN,
};

View File

@ -120,6 +120,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_100B_A6B: return "100B.A6B";
case LLM_TYPE_106B_A12B: return "106B.A12B";
case LLM_TYPE_230B_A10B: return "230B.A10B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_300B_A47B: return "300B.A47B";
case LLM_TYPE_355B_A32B: return "355B.A32B";
@ -2155,6 +2156,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MINIMAX_M2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
switch (hparams.n_layer) {
case 62: type = LLM_TYPE_230B_A10B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_COGVLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -6185,6 +6197,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_MINIMAX_M2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
}
} break;
case LLM_ARCH_COGVLM:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -20024,6 +20065,130 @@ struct llm_build_apertus : public llm_graph_context {
}
};
struct llm_build_minimax_m2 : public llm_graph_context {
llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
// GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * inp_pos = build_inp_pos();
auto inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
cur = inpL;
// self_attention
{
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL,
LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL,
LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(cur, "ffn_moe_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
struct llm_build_cogvlm : public llm_graph_context {
llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -20654,6 +20819,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_apertus>(*this, params);
} break;
case LLM_ARCH_MINIMAX_M2:
{
llm = std::make_unique<llm_build_minimax_m2>(*this, params);
} break;
case LLM_ARCH_COGVLM:
{
llm = std::make_unique<llm_build_cogvlm>(*this, params);
@ -20875,6 +21044,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_SEED_OSS:
case LLM_ARCH_GROVEMOE:
case LLM_ARCH_APERTUS:
case LLM_ARCH_MINIMAX_M2:
case LLM_ARCH_COGVLM:
return LLAMA_ROPE_TYPE_NEOX;

View File

@ -114,6 +114,7 @@ enum llm_type {
LLM_TYPE_30B_A3B,
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_230B_A10B, // Minimax M2
LLM_TYPE_235B_A22B,
LLM_TYPE_300B_A47B, // Ernie MoE big
LLM_TYPE_355B_A32B, // GLM-4.5

View File

@ -401,6 +401,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
case LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2:
regex_exprs = {
// original regex from tokenizer.json
// "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@ -1992,6 +1993,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "grok-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
clean_spaces = false;
} else if (
tokenizer_pre == "minimax-m2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2;
clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}

View File

@ -49,6 +49,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41,
};
struct LLM_KV;