From 3d4b6c7fd24707c8f8028b91b72957963f3ed7d3 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 16 Jan 2026 17:08:12 +0100
Subject: [PATCH 1/3] add glm4-moe mtp cgraph

---
 src/llama-graph.cpp     |  25 ++
 src/llama-graph.h       |  24 ++-
 src/llama-model.cpp     |   6 +-
 src/models/glm4-moe.cpp | 322 ++++++++++++++++++++++++----------------
 src/models/models.h     |   9 ++
 5 files changed, 259 insertions(+), 127 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 944c7e53bd..901027c0ef 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -281,6 +281,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_cross_mtp::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (cross_mtp && !cross->mtp_embd.empty()) {
+        assert(cross_mtp->type == GGML_TYPE_F32);
+        assert(ggml_nelements(cross_mtp) == (int64_t)cross->mtp_embd.size());
+
+        ggml_backend_tensor_set(cross_mtp, cross->mtp_embd.data(), 0, ggml_nbytes(cross_mtp));
+    }
+}
+
 static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
     const char * swa_type_str = "unknown";
@@ -1419,6 +1430,20 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
+    auto inp = std::make_unique(hparams.n_pos_per_embd());
+
+    auto & cur = inp->cross_mtp;
+
+    GGML_ASSERT(cross != nullptr);
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cross->n_embd, cross->n_token);
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
     auto inp = std::make_unique(hparams);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 503ffd695a..55c83830b3 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -55,18 +55,24 @@ enum llm_norm_type {
 };
 
 // TODO: tmp - need something better to pass the data from the encoder to the decoder
+// currently also used to pass embeddings from the main model to the MTP layers
 struct llama_cross {
     // the output embeddings from the encoder as a ggml tensor
     // TODO: this needs more work to be correct, for now copy the embeddings data to host memory
     // ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
     //ggml_tensor * t_embd = nullptr;
 
-    int64_t n_embd = 0;
-    int64_t n_enc  = 0;
+    int64_t n_embd  = 0;
+    int64_t n_enc   = 0;
+    int64_t n_token = 0; // used by mtp
 
     // embeddings data copied to host memory (tmp)
     std::vector<float> v_embd;
 
+    // embeddings data to be passed to MTP layers
+    // TODO: optimize by using ggml_tensor here
+    std::vector<float> mtp_embd;
+
     // needed to construct the cross-attention mask in the decoder
     std::vector<std::set<llama_seq_id>> seq_ids_enc;
 };
@@ -255,6 +261,18 @@ public:
 
     const llama_cross * cross;
 };
+class llm_graph_input_cross_mtp : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_mtp(
+            const llama_cross * cross) : cross(cross) {}
+    virtual ~llm_graph_input_cross_mtp() = default;
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_mtp; // F32 [n_embd, n_token]
+
+    const llama_cross * cross;
+};
+
 class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -426,6 +444,7 @@ class llm_graph_result;
 
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
+    bool is_mtp = false;
 
     llama_hparams hparams;
     llama_cparams cparams;
@@ -756,6 +775,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_cls() const;
 
     ggml_tensor * build_inp_cross_embd() const;
+    ggml_tensor * build_inp_cross_mtp() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 94c47dc248..8da41ebeeb 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7871,7 +7871,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique(*this, params);
+                if (params.is_mtp) {
+                    llm = std::make_unique>(*this, params);
+                } else {
+                    llm = std::make_unique>(*this, params);
+                }
             } break;
         case LLM_ARCH_BITNET:
             {
diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp
index 003f70f739..6e81167671 100644
--- a/src/models/glm4-moe.cpp
+++ b/src/models/glm4-moe.cpp
@@ -1,7 +1,9 @@
 #include "models.h"
 
-llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <>
+llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
+    const bool use_mrope = hparams.use_mrope();
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13,7 +15,6 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 
     inpL = build_inp_embd(model.tok_embd);
 
-    bool use_mrope = hparams.use_mrope();
     if (ubatch.embd && !use_mrope) {
         // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
         GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
@@ -30,129 +31,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 
     // Final layer tensors are loaded but not processed in forward pass
     const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
     for (int il = 0; il < n_transformer_layers; ++il) {
-        ggml_tensor * inpSA = inpL;
-
-        // Pre-attention norm
-        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
-        cb(cur, "attn_norm", il);
-
-        // self-attention
-        {
-            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-            if (model.layers[il].bq) {
-                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-            }
-            cb(Qcur, "Qcur", il);
-
-            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-            if (model.layers[il].bk) {
-                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-            }
-            cb(Kcur, "Kcur", il);
-
-            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-            if (model.layers[il].bv) {
-                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-            }
-            cb(Vcur, "Vcur", il);
-
-            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-            // Apply Q/K norm if available (GLM-4.5 355B variant)
-            if (model.layers[il].attn_q_norm) {
-                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
-                cb(Qcur, "Qcur_normed", il);
-            }
-            if (model.layers[il].attn_k_norm) {
-                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
-                cb(Kcur, "Kcur_normed", il);
-            }
-
-            if (use_mrope) {
-                Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
-                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-
-                Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
-                        n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            } else {
-                // Normal RoPE
-                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
-                        rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-
-                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
-                        rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor, beta_fast, beta_slow);
-            }
-
-            cb(Qcur, "Qcur", il);
-            cb(Kcur, "Kcur", il);
-            cb(Vcur, "Vcur", il);
-
-            cur = build_attn(inp_attn,
-                    model.layers[il].wo, NULL,
-                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-        }
-        if (il == n_transformer_layers - 1 && inp_out_ids) {
-            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
-            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
-        }
-        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(ffn_inp, "ffn_inp", il);
-
-        // Post-attention norm
-        cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
-        cb(cur, "post_attn_norm", il);
-
-        // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
-        if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
-            // Dense FFN layer
-            cur = build_ffn(cur,
-                    model.layers[il].ffn_up,   NULL, NULL,
-                    model.layers[il].ffn_gate, NULL, NULL,
-                    model.layers[il].ffn_down, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
-            cb(cur, "ffn_out", il);
-        } else {
-            // Process routed experts using existing MoE infrastructure
-            ggml_tensor * routed_out = build_moe_ffn(cur,
-                    model.layers[il].ffn_gate_inp,
-                    model.layers[il].ffn_up_exps,
-                    model.layers[il].ffn_gate_exps,
-                    model.layers[il].ffn_down_exps,
-                    model.layers[il].ffn_exp_probs_b,
-                    n_expert, n_expert_used,
-                    LLM_FFN_SILU, hparams.expert_weights_norm,
-                    true, hparams.expert_weights_scale,
-                    (llama_expert_gating_func_type) hparams.expert_gating_func,
-                    il);
-            cb(routed_out, "ffn_moe_out", il);
-
-            // Process shared expert on original input
-            ggml_tensor * shared_out = build_ffn(cur,
-                    model.layers[il].ffn_up_shexp,   NULL, NULL,
-                    model.layers[il].ffn_gate_shexp, NULL, NULL,
-                    model.layers[il].ffn_down_shexp, NULL, NULL,
-                    NULL,
-                    LLM_FFN_SILU, LLM_FFN_PAR, il);
-            cb(shared_out, "ffn_shexp_out", il);
-
-            // Final output: routed_output + shared_output
-            cur = ggml_add(ctx0, routed_out, shared_out);
-            cb(cur, "ffn_out", il);
-        }
-        cur = ggml_add(ctx0, cur, ffn_inp);
-
-        cur = build_cvec(cur, il);
-        cb(cur, "l_out", il);
-        // input for next layer
-        inpL = cur;
+        bool is_output_layer = (il == n_transformer_layers - 1);
+        inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
     }
     cur = inpL;
     cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
@@ -168,3 +49,196 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
 
     ggml_build_forward_expand(gf, cur);
 }
+
+// MTP model
+template <>
+llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const bool use_mrope = hparams.use_mrope();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    // for now, we only support a single NextN layer for simplicity
+    GGML_ASSERT(hparams.nextn_predict_layers == 1);
+    const int il = n_layer - hparams.nextn_predict_layers;
+    auto & mtp_layer = model.layers[il];
+
+    ggml_tensor * inp_token_embd = build_inp_embd(mtp_layer.nextn.embed_tokens // can be nullptr on GLM-4.6
+        ? mtp_layer.nextn.embed_tokens : model.tok_embd);
+    ggml_tensor * inp_state_embd = build_inp_cross_mtp();
+
+    inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
+    inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
+    inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0);
+    cb(inpL, "inp_mtp", il);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+    {
+        // input for next layer
+        bool is_output_layer = (il == n_layer - 1);
+        inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
+    }
+    cur = inpL;
+    cur = build_norm(cur, mtp_layer.nextn.shared_head_norm // can be nullptr on GLM-4.6
+        ? mtp_layer.nextn.shared_head_norm : model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(mtp_layer.nextn.shared_head_head // can be nullptr on GLM-4.6
+        ? mtp_layer.nextn.shared_head_head : model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+template
+ggml_tensor * llm_build_glm4_moe::build_layer(const llama_model & model,
+                llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * inpL,
+                ggml_tensor * inp_pos,
+                ggml_tensor * inp_out_ids,
+                int sections[4],
+                bool is_output_layer,
+                int il) {
+    bool use_mrope = hparams.use_mrope();
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    ggml_tensor * inpSA = inpL;
+
+    // Pre-attention norm
+    ggml_tensor * cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+    cb(cur, "attn_norm", il);
+
+    // self-attention
+    {
+        ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        if (model.layers[il].bq) {
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+        }
+        cb(Qcur, "Qcur", il);
+
+        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+        if (model.layers[il].bk) {
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+        }
+        cb(Kcur, "Kcur", il);
+
+        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        if (model.layers[il].bv) {
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+        }
+        cb(Vcur, "Vcur", il);
+
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+        // Apply Q/K norm if available (GLM-4.5 355B variant)
+        if (model.layers[il].attn_q_norm) {
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+        }
+        if (model.layers[il].attn_k_norm) {
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Kcur, "Kcur_normed", il);
+        }
+
+        if (use_mrope) {
+            Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+        } else {
+            // Normal RoPE
+            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
+                    rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
+                    rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+        }
+
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        cb(Vcur, "Vcur", il);
+
+        cur = build_attn(inp_attn,
+                model.layers[il].wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+    }
+    if (is_output_layer && inp_out_ids) {
+        cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+        inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+    }
+    ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+    cb(ffn_inp, "ffn_inp", il);
+
+    // Post-attention norm
+    cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+    cb(cur, "post_attn_norm", il);
+
+    // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
+    if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
+        // Dense FFN layer
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+    } else {
+        // Process routed experts using existing MoE infrastructure
+        ggml_tensor * routed_out = build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, hparams.expert_weights_norm,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+        cb(routed_out, "ffn_moe_out", il);
+
+        // Process shared expert on original input
+        ggml_tensor * shared_out = build_ffn(cur,
+                model.layers[il].ffn_up_shexp,   NULL, NULL,
+                model.layers[il].ffn_gate_shexp, NULL, NULL,
+                model.layers[il].ffn_down_shexp, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(shared_out, "ffn_shexp_out", il);
+
+        // Final output: routed_output + shared_output
+        cur = ggml_add(ctx0, routed_out, shared_out);
+        cb(cur, "ffn_out", il);
+    }
+    cur = ggml_add(ctx0, cur, ffn_inp);
+
+    cur = build_cvec(cur, il);
+    cb(cur, "l_out", il);
+
+    return cur;
+}
diff --git a/src/models/models.h b/src/models/models.h
index 3a44f7f140..346f779fad 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -222,8 +222,17 @@ struct llm_build_glm4 : public llm_graph_context {
     llm_build_glm4(const llama_model & model, const llm_graph_params & params);
 };
 
+template
 struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);
+    ggml_tensor * build_layer(const llama_model & model,
+                    llm_graph_input_attn_kv * inp_attn,
+                    ggml_tensor * inpL,
+                    ggml_tensor * inp_pos,
+                    ggml_tensor * inp_out_ids,
+                    int sections[4],
+                    bool is_output_layer,
+                    int il);
 };
 
 struct llm_build_gpt2 : public llm_graph_context {
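Note on PATCH 1/3: the MTP cgraph feeds the NextN layer with the concatenation of two normalized inputs, the embedding of the incoming token (nextn.enorm) and the hidden state produced by the main model for that position (nextn.hnorm). The standalone sketch below is not part of the patch; it mirrors only that input-construction step with plain ggml calls, and the tensor sizes are made up for illustration.

    // minimal sketch of the MTP input construction: rms_norm both inputs, then
    // concatenate along the feature dimension, mirroring
    // inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0) in the patch
    #include "ggml.h"
    #include <cstdio>

    int main() {
        const int64_t n_embd   = 8; // hypothetical embedding size
        const int64_t n_tokens = 2; // hypothetical number of tokens

        ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(ip);

        // stand-ins for the token embedding and the hidden state coming from the main model
        ggml_tensor * tok_embd   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
        ggml_tensor * state_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

        // normalize both inputs (enorm/hnorm in the patch), then concat -> [2*n_embd, n_tokens]
        ggml_tensor * inp = ggml_concat(ctx,
                ggml_rms_norm(ctx, tok_embd,   1e-5f),
                ggml_rms_norm(ctx, state_embd, 1e-5f), 0);

        printf("MTP input shape: [%lld, %lld]\n", (long long) inp->ne[0], (long long) inp->ne[1]);

        ggml_free(ctx);
        return 0;
    }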
From b4457e48bb23f15babff30d5ecf81d7816c4256c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 16 Jan 2026 23:38:54 +0100
Subject: [PATCH 2/3] wip

---
 include/llama.h       |  8 ++++++++
 src/llama-context.cpp | 47 ++++++++++++++++++++++++++++++++++++++++---
 src/llama-context.h   |  2 +-
 src/llama-graph.cpp   |  2 +-
 src/llama-graph.h     |  2 +-
 src/llama-hparams.cpp |  4 ++++
 src/llama-hparams.h   |  3 +++
 src/llama-model.cpp   |  2 +-
 8 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 280745713e..71537b3d53 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -939,6 +939,14 @@ extern "C" {
             struct llama_context * ctx,
             struct llama_batch batch);
 
+    // Process a batch of tokens using MTP (Multi-Token Prediction).
+    // The input token can be either from the last llama_decode() call,
+    // or from the previous llama_decode_mtp() call.
+    // Input token order must match the output token order from the previous call.
+    LLAMA_API int32_t llama_decode_mtp(
+            struct llama_context * ctx,
+            struct llama_batch batch);
+
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a6d5ddfa33..d9a00ee748 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1456,9 +1456,21 @@ static void copy_tensor_async_candidates(
     }
 }
 
-int llama_context::decode(const llama_batch & batch_inp) {
+int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
+    if (is_mtp) {
+        if (model.hparams.nextn_predict_layers == 0) {
+            LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
+            return -1;
+        }
+        if (batch_inp.n_tokens > n_ubatch()) {
+            // TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
+            LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
+            return -1;
+        }
+    }
+
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -1587,6 +1599,15 @@
             break;
         }
 
+    const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0;
+
+    // set MTP state if needed
+    if (update_mtp_state) {
+        cross.n_embd = hparams.get_n_embd_mtp();
+        cross.n_token = n_outputs;
+        cross.mtp_embd.resize(cross.n_embd*cross.n_token);
+    }
+
     // reserve output buffer
     if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
@@ -1615,7 +1636,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         ggml_status status;
-        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        llm_graph_type gtype = is_mtp ? LLM_GRAPH_TYPE_DECODER_MTP : LLM_GRAPH_TYPE_DECODER;
+        const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
@@ -1683,6 +1705,14 @@
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
+            // set MTP state if needed
+            if (update_mtp_state) {
+                const int64_t n_embd_mtp = cross.n_embd;
+                GGML_ASSERT( n_outputs_prev + n_outputs             <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_mtp <= (int64_t)cross.mtp_embd.size());
+                ggml_backend_tensor_get_async(backend_embd, t_embd, cross.mtp_embd.data(), 0, n_outputs*n_embd_mtp*sizeof(float));
+            }
+
             switch (cparams.pooling_type) {
                 case LLAMA_POOLING_TYPE_NONE:
                     {
@@ -3523,7 +3553,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    const int ret = ctx->decode(batch);
+    const int ret = ctx->decode(batch, false);
     if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -3531,6 +3561,17 @@ int32_t llama_decode(
     return ret;
 }
 
+int32_t llama_decode_mtp(
+        llama_context * ctx,
+        llama_batch batch) {
+    const int ret = ctx->decode(batch, true);
+    if (ret != 0 && ret != 1) {
+        LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
 //
 // perf
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index 86decc05fb..aaf58a533c 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -131,7 +131,7 @@ struct llama_context {
             ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
-    int decode(const llama_batch & batch_inp);
+    int decode(const llama_batch & batch_inp, bool is_mtp);
 
     //
     // state save/load
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 901027c0ef..a6dd1525d1 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1431,7 +1431,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
-    auto inp = std::make_unique(hparams.n_pos_per_embd());
+    auto inp = std::make_unique(cross);
 
     auto & cur = inp->cross_mtp;
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 55c83830b3..adbd335137 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -30,6 +30,7 @@ enum llm_graph_type {
     LLM_GRAPH_TYPE_DEFAULT,
     LLM_GRAPH_TYPE_ENCODER,
     LLM_GRAPH_TYPE_DECODER,
+    LLM_GRAPH_TYPE_DECODER_MTP,
 };
 
 enum llm_ffn_op_type {
@@ -444,7 +445,6 @@ class llm_graph_result;
 
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
-    bool is_mtp = false;
 
     llama_hparams hparams;
     llama_cparams cparams;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index c847ef91b7..29dbd08f32 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -76,6 +76,10 @@ uint32_t llama_hparams::get_n_embd_out() const {
     return n_embd_out > 0 ? n_embd_out : n_embd;
 }
 
+uint32_t llama_hparams::get_n_embd_mtp() const {
+    return n_embd;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 7ae3ec292e..588ce2987b 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -240,6 +240,9 @@ struct llama_hparams {
     // dimension of output embeddings
     uint32_t get_n_embd_out() const;
 
+    // dimension of cross embeddings between main LLM and MTP
+    uint32_t get_n_embd_mtp() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8da41ebeeb..f1194afe4b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7871,7 +7871,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                if (params.is_mtp) {
+                if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
                     llm = std::make_unique>(*this, params);
                 } else {
                     llm = std::make_unique>(*this, params);
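Note on PATCH 2/3: llama_decode() now copies the output-row embeddings of the last graph into llama_cross (cross.mtp_embd, sized cross.n_embd * cross.n_token floats), and llama_decode_mtp() reuses the decode path with the LLM_GRAPH_TYPE_DECODER_MTP graph type. The sketch below shows how a caller might drive the API exactly as it stands after this commit; the token and sampling plumbing are assumptions, and PATCH 3/3 replaces this entry point.

    // sketch only: draft one token with the NextN/MTP layer right after a
    // normal llama_decode() + sampling step on the same context
    #include "llama.h"

    static int draft_one_token_mtp(llama_context * ctx, llama_token sampled) {
        // `sampled` is assumed to be the token just sampled from the main model's last output
        llama_batch batch = llama_batch_get_one(&sampled, 1);
        return llama_decode_mtp(ctx, batch); // MTP logits are then read as usual, e.g. llama_get_logits_ith(ctx, -1)
    }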
From de17303dc63ad478edf7314961aa3777a8f2cf14 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 17 Jan 2026 20:54:09 +0100
Subject: [PATCH 3/3] llama_mtp_start

---
 include/llama.h         | 13 +++++--------
 src/llama-context.cpp   | 43 +++++++++++++++++++++++++++--------------
 src/llama-context.h     |  6 +++++-
 src/models/glm4-moe.cpp |  4 +++-
 4 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 71537b3d53..7e9518bdae 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -370,6 +370,7 @@ extern "C" {
         bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        bool is_mtp;      // create context for Multi-Token Prediction (MTP)
 
         // [EXPERIMENTAL]
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
@@ -939,14 +940,6 @@ extern "C" {
             struct llama_context * ctx,
             struct llama_batch batch);
 
-    // Process a batch of tokens using MTP (Multi-Token Prediction).
-    // The input token can be either from the last llama_decode() call,
-    // or from the previous llama_decode_mtp() call.
-    // Input token order must match the output token order from the previous call.
-    LLAMA_API int32_t llama_decode_mtp(
-            struct llama_context * ctx,
-            struct llama_batch batch);
-
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
@@ -1014,6 +1007,10 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Copy the internal MTP state from ctx_llm to ctx_mtp, ready for MTP decoding.
+    // This must be done before calling llama_decode() on ctx_mtp
+    LLAMA_API int32_t llama_mtp_start(struct llama_context * ctx_llm, struct llama_context * ctx_mtp);
+
     //
     // backend sampling API [EXPERIMENTAL]
     // note: use only if the llama_context was created with at least one llama_sampler_seq_config
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d9a00ee748..a7bb2882ae 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -27,6 +27,8 @@ llama_context::llama_context(
     // may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
+    is_mtp = params.is_mtp;
+
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
 
@@ -814,6 +816,23 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
+    if (!ctx_mtp.is_mtp) {
+        LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
+        return -1;
+    }
+
+    if (cross.n_token == 0 || cross.n_embd == 0) {
+        LLAMA_LOG_ERROR("%s: no state to copy\n", __func__);
+        return -1;
+    }
+
+    // TODO: maybe std::move is better?
+    ctx_mtp.cross = cross;
+
+    return 0;
+}
+
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
@@ -1456,7 +1475,7 @@ static void copy_tensor_async_candidates(
     }
 }
 
-int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
+int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
     if (is_mtp) {
@@ -1464,7 +1483,7 @@ int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
             LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
             return -1;
         }
-        if (batch_inp.n_tokens > n_ubatch()) {
+        if ((uint32_t)batch_inp.n_tokens > n_ubatch()) {
             // TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
             LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
             return -1;
@@ -3043,6 +3062,7 @@ llama_context_params llama_context_default_params() {
         /*.op_offload        =*/ true,
        /*.swa_full          =*/ true,
         /*.kv_unified        =*/ false,
+        /*.is_mtp            =*/ false,
         /*.sampler           =*/ nullptr,
         /*.n_sampler         =*/ 0,
     };
@@ -3233,6 +3253,12 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
+    ctx_llm->synchronize();
+
+    return ctx_llm->cpy_mtp_state(*ctx_mtp);
+}
+
 bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
     return ctx->set_sampler(seq_id, smpl);
 }
@@ -3553,7 +3579,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    const int ret = ctx->decode(batch, false);
+    const int ret = ctx->decode(batch);
     if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -3561,17 +3587,6 @@ int32_t llama_decode(
     return ret;
 }
 
-int32_t llama_decode_mtp(
-        llama_context * ctx,
-        llama_batch batch) {
-    const int ret = ctx->decode(batch, true);
-    if (ret != 0 && ret != 1) {
-        LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
 //
 // perf
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index aaf58a533c..265c339d62 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -77,6 +77,8 @@ struct llama_context {
     float * get_embeddings();
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
+
+    int32_t cpy_mtp_state(llama_context & ctx_mtp);
 
     llama_token * get_sampled_tokens() const;
     llama_token   get_sampled_token_ith(int32_t idx);
@@ -131,7 +133,7 @@ struct llama_context {
             ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
-    int decode(const llama_batch & batch_inp, bool is_mtp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
@@ -349,6 +351,8 @@ private:
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
+    bool is_mtp = false;
+
     bool has_evaluated_once = false;
 
     // env: LLAMA_GRAPH_REUSE_DISABLE
diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp
index 6e81167671..8a87c63faf 100644
--- a/src/models/glm4-moe.cpp
+++ b/src/models/glm4-moe.cpp
@@ -54,7 +54,6 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const l
 template <>
 llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
-    const bool use_mrope = hparams.use_mrope();
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -73,6 +72,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const ll
         ? mtp_layer.nextn.embed_tokens : model.tok_embd);
     ggml_tensor * inp_state_embd = build_inp_cross_mtp();
 
+    // check number of input tokens
+    GGML_ASSERT(inp_state_embd->ne[1] == inp_token_embd->ne[1]);
+
     inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
     inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
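Note on PATCH 3/3: the per-call flag is replaced by a dedicated MTP context (is_mtp in llama_context_params) plus llama_mtp_start(), which synchronizes the main context and copies its llama_cross state into the MTP context; MTP decoding then goes through the regular llama_decode(). The end-to-end sketch below is a hedged illustration of the intended call sequence only; model loading, batching, error-path cleanup, and sampling details are assumptions, not part of the patch.

    #include "llama.h"

    // sketch: run one MTP draft step for `sampled`, the token just sampled from the main context
    static int run_mtp_draft(llama_model * model, llama_token sampled) {
        llama_context_params cparams = llama_context_default_params();
        llama_context * ctx_llm = llama_init_from_model(model, cparams);

        cparams.is_mtp = true;
        llama_context * ctx_mtp = llama_init_from_model(model, cparams);

        // 1) normal decode on the main context (prompt processing omitted, error paths simplified)
        llama_batch batch = llama_batch_get_one(&sampled, 1);
        if (llama_decode(ctx_llm, batch) != 0) {
            return -1;
        }

        // 2) hand the MTP state (the main model's output embeddings) over to the MTP context
        if (llama_mtp_start(ctx_llm, ctx_mtp) != 0) {
            return -1;
        }

        // 3) draft with the NextN layer; input token order must match the previous call's output order
        llama_batch draft = llama_batch_get_one(&sampled, 1);
        const int ret = llama_decode(ctx_mtp, draft);

        llama_free(ctx_mtp);
        llama_free(ctx_llm);
        return ret;
    }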