glm4: add MTP weight fallback for GLM-4.6 compatibility

GLM-4.6 checkpoints omit the MTP-specific tensors (`embed_tokens` and `shared_head_head`), implying that these weights are tied to the main model. Previously, the missing tensors caused a crash when building the graph.

This commit adds a fallback mechanism to use the main model's token embeddings and output head when the MTP-specific tensors are missing.
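
Conceptually, the fallback is just a null check on the optional MTP tensors before they are wired into the graph. The sketch below is illustrative only: the structs are simplified stand-ins for the real llama.cpp types, and the tensor names are placeholders. The authoritative change is in the diff below.

```cpp
// Minimal sketch of the fallback, using simplified stand-ins for the real
// llama.cpp types; the actual implementation is in the diff below.
#include <cstdio>

struct ggml_tensor { const char * name; };

struct nextn_weights {
    // GLM-4.6 checkpoints ship without these two MTP-specific tensors.
    ggml_tensor * embed_tokens     = nullptr;
    ggml_tensor * shared_head_head = nullptr;
};

struct main_weights {
    ggml_tensor tok_embd = { "token_embd.weight" };
    ggml_tensor output   = { "output.weight" };
};

// Use the MTP-specific tensor when present, otherwise fall back to the
// tied main-model tensor instead of passing a null pointer into the graph.
static ggml_tensor * pick(ggml_tensor * mtp_specific, ggml_tensor * fallback) {
    return mtp_specific != nullptr ? mtp_specific : fallback;
}

int main() {
    nextn_weights nextn;   // both tensors missing, as in GLM-4.6
    main_weights  model;

    ggml_tensor * embd = pick(nextn.embed_tokens,     &model.tok_embd);
    ggml_tensor * head = pick(nextn.shared_head_head, &model.output);

    std::printf("embeddings: %s\nlm head:    %s\n", embd->name, head->name);
    return 0;
}
```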
Author: samuel, 2025-12-10 22:54:27 -03:00 (committed by Aaron Lee)
parent 38c91187f9
commit d9576dd037
2 changed files with 19 additions and 5 deletions


@@ -32,7 +32,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
         const int il_mtp = hparams.n_layer - 1;
         const auto & mtp_layer = model.layers[il_mtp];
-        res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);
+        res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, model);
     } else {
         ggml_tensor * inpL;
@@ -196,7 +197,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 }
 
-ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head) {
+ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
+                                                 int64_t n_embd_head, const llama_model & model) {
     ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);
     cb(embd_copy, "mtp_embd_copy", -1);
@@ -204,7 +206,13 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer,
     ggml_tensor * inp_pos = build_inp_pos();
     auto * inp_attn = build_attn_inp_kv();
 
-    ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
+    // If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd
+    ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
+    if (mtp_embd_weights == nullptr) {
+        mtp_embd_weights = model.tok_embd;
+    }
+    ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
 
     ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
     ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
@@ -304,7 +312,13 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer,
     cur = ggml_add(ctx0, cur, ffn_inp);
     cb(cur, "mtp_ffn_out_resid", il);
 
     cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
-    cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
+    // If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head)
+    ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
+    if (mtp_head_weights == nullptr) {
+        mtp_head_weights = model.output;
+    }
+    cur = build_lora_mm(mtp_head_weights, cur);
 
     return cur;
 }


@@ -221,7 +221,7 @@ struct llm_build_glm4 : public llm_graph_context {
 struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head);
+    ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head, const llama_model & model);
 };
 
 struct llm_build_gpt2 : public llm_graph_context {