add glm4-moe mtp cgraph

Xuan Son Nguyen 2026-01-16 17:08:12 +01:00
parent 0802d4cfb3
commit 3d4b6c7fd2
5 changed files with 259 additions and 127 deletions

View File

@@ -281,6 +281,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_cross_mtp::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
if (cross_mtp && !cross->mtp_embd.empty()) {
assert(cross_mtp->type == GGML_TYPE_F32);
assert(ggml_nelements(cross_mtp) == (int64_t)cross->mtp_embd.size());
ggml_backend_tensor_set(cross_mtp, cross->mtp_embd.data(), 0, ggml_nbytes(cross_mtp));
}
}
static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = "unknown";
@@ -1419,6 +1430,20 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
auto inp = std::make_unique<llm_graph_input_cross_mtp>(cross);
auto & cur = inp->cross_mtp;
GGML_ASSERT(cross != nullptr);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cross->n_embd, cross->n_token);
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

View File

@@ -55,18 +55,24 @@ enum llm_norm_type {
};
// TODO: tmp - need something better to pass the data from the encoder to the decoder
// currently also used for passing embeddings from the main model to the MTP layers
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
//ggml_tensor * t_embd = nullptr;
int64_t n_embd = 0;
int64_t n_enc = 0;
int64_t n_token = 0; // used by mtp
// embeddings data copied to host memory (tmp)
std::vector<float> v_embd;
// embeddings data to be passed to MTP layers
// TODO: optimize by using ggml_tensor here
std::vector<float> mtp_embd;
// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
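Side note (not part of this diff): cross.mtp_embd has to be populated on the llama_context side after the main model's graph has been computed and before the MTP graph is built. A minimal sketch of that copy, assuming the last hidden states are available as a ggml tensor t_embd (F32 [n_embd, n_token]) and cross is the llama_cross instance handed to the decoder:
// hypothetical glue code (names assumed), run after the main graph has been computed
ggml_tensor * t_embd = res->t_embd;            // last hidden states of the main model
cross.n_embd  = t_embd->ne[0];
cross.n_token = t_embd->ne[1];
cross.mtp_embd.resize(ggml_nelements(t_embd));
ggml_backend_tensor_get(t_embd, cross.mtp_embd.data(), 0, ggml_nbytes(t_embd));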
@@ -255,6 +261,18 @@ public:
const llama_cross * cross;
};
class llm_graph_input_cross_mtp : public llm_graph_input_i {
public:
llm_graph_input_cross_mtp(
const llama_cross * cross) : cross(cross) {}
virtual ~llm_graph_input_cross_mtp() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * cross_mtp; // F32 [n_embd, n_token]
const llama_cross * cross;
};
class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -426,6 +444,7 @@ class llm_graph_result;
struct llm_graph_params {
llm_arch arch = LLM_ARCH_UNKNOWN;
bool is_mtp = false;
llama_hparams hparams;
llama_cparams cparams;
@@ -756,6 +775,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_cross_mtp() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

View File

@@ -7871,7 +7871,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GLM4_MOE:
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
if (params.is_mtp) {
llm = std::make_unique<llm_build_glm4_moe<true>>(*this, params);
} else {
llm = std::make_unique<llm_build_glm4_moe<false>>(*this, params);
}
} break;
case LLM_ARCH_BITNET:
{

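Usage note (not part of this diff): the MTP builder is selected solely through the new is_mtp flag, so a caller that wants the NextN graph can reuse the parameters of the regular graph build. A hedged sketch, with the surrounding llama_context plumbing assumed:
// hypothetical caller-side code, e.g. in a speculative-decoding path
llm_graph_params params_mtp = params;  // same params as the main graph build
params_mtp.is_mtp = true;              // dispatches to llm_build_glm4_moe<true>
ggml_cgraph * gf_mtp = model.build_graph(params_mtp);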
View File

@@ -1,7 +1,9 @@
#include "models.h"
llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
template <>
llm_build_glm4_moe<false>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const bool use_mrope = hparams.use_mrope();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13,7 +15,6 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
inpL = build_inp_embd(model.tok_embd);
bool use_mrope = hparams.use_mrope();
if (ubatch.embd && !use_mrope) {
// unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
@@ -30,129 +31,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// Pre-attention norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
if (use_mrope) {
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
// Normal RoPE
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// Post-attention norm
cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_attn_norm", il);
// Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
// Dense FFN layer
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(routed_out, "ffn_moe_out", il);
// Process shared expert on original input
ggml_tensor * shared_out = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shared_out, "ffn_shexp_out", il);
// Final output: routed_output + shared_output
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
bool is_output_layer = (il == n_transformer_layers - 1);
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
@@ -168,3 +49,196 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
ggml_build_forward_expand(gf, cur);
}
// MTP model
template <>
llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const bool use_mrope = hparams.use_mrope();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
ggml_tensor * cur;
ggml_tensor * inpL;
// for now, we only support a single NextN layer for simplicity
GGML_ASSERT(hparams.nextn_predict_layers == 1);
const int il = n_layer - hparams.nextn_predict_layers;
auto & mtp_layer = model.layers[il];
ggml_tensor * inp_token_embd = build_inp_embd(mtp_layer.nextn.embed_tokens // can be nullptr on GLM-4.6
? mtp_layer.nextn.embed_tokens : model.tok_embd);
ggml_tensor * inp_state_embd = build_inp_cross_mtp();
inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0);
// eh_proj reduces the concatenated [2*n_embd] input back to n_embd (assumed here;
// without it the attention matmuls in build_layer would not match dimensions)
inpL = build_lora_mm(mtp_layer.nextn.eh_proj, inpL);
cb(inpL, "inp_mtp", il);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
{
// input for next layer
bool is_output_layer = (il == n_layer - 1);
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
}
cur = inpL;
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm // can be nullptr on GLM-4.6
? mtp_layer.nextn.shared_head_norm : model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(mtp_layer.nextn.shared_head_head // can be nullptr on GLM-4.6
? mtp_layer.nextn.shared_head_head : model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
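For reference, the input construction above follows the DeepSeek-V3-style NextN/MTP formulation that GLM-4.5 reuses (paraphrased here, not taken from this diff): with $e_i$ the embedding of the draft token and $h_i$ the main model's hidden state at the same position,
$$h^{\mathrm{in}}_i = W_{eh}\,\big[\,\mathrm{RMSNorm}(e_i)\;;\;\mathrm{RMSNorm}(h_i)\,\big]$$
which is what the enorm/hnorm + concat + eh_proj sequence computes before the single transformer block and the shared-head output.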
template <bool is_mtp>
ggml_tensor * llm_build_glm4_moe<is_mtp>::build_layer(const llama_model & model,
llm_graph_input_attn_kv * inp_attn,
ggml_tensor * inpL,
ggml_tensor * inp_pos,
ggml_tensor * inp_out_ids,
int sections[4],
bool is_output_layer,
int il) {
bool use_mrope = hparams.use_mrope();
const int64_t n_embd_head = hparams.n_embd_head_v;
ggml_tensor * inpSA = inpL;
// Pre-attention norm
ggml_tensor * cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
if (use_mrope) {
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
// Normal RoPE
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (is_output_layer && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// Post-attention norm
cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_attn_norm", il);
// Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
// Dense FFN layer
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(routed_out, "ffn_moe_out", il);
// Process shared expert on original input
ggml_tensor * shared_out = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shared_out, "ffn_shexp_out", il);
// Final output: routed_output + shared_output
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
return cur;
}

View File

@@ -222,8 +222,17 @@ struct llm_build_glm4 : public llm_graph_context {
llm_build_glm4(const llama_model & model, const llm_graph_params & params);
};
template <bool is_mtp>
struct llm_build_glm4_moe : public llm_graph_context {
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
ggml_tensor * build_layer(const llama_model & model,
llm_graph_input_attn_kv * inp_attn,
ggml_tensor * inpL,
ggml_tensor * inp_pos,
ggml_tensor * inp_out_ids,
int sections[4],
bool is_output_layer,
int il);
};
struct llm_build_gpt2 : public llm_graph_context {