From b4457e48bb23f15babff30d5ecf81d7816c4256c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 16 Jan 2026 23:38:54 +0100
Subject: [PATCH] wip

---
 include/llama.h       |  8 ++++++++
 src/llama-context.cpp | 47 ++++++++++++++++++++++++++++++++++++++++---
 src/llama-context.h   |  2 +-
 src/llama-graph.cpp   |  2 +-
 src/llama-graph.h     |  2 +-
 src/llama-hparams.cpp |  4 ++++
 src/llama-hparams.h   |  3 +++
 src/llama-model.cpp   |  2 +-
 8 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 280745713e..71537b3d53 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -939,6 +939,14 @@ extern "C" {
             struct llama_context * ctx,
             struct llama_batch   batch);
 
+    // Process a batch of tokens using MTP (Multi-Token Prediction).
+    // The input token can be either from the last llama_decode() call,
+    // or from the previous llama_decode_mtp() call.
+    // Input token order must match the output token order from the previous call.
+    LLAMA_API int32_t llama_decode_mtp(
+            struct llama_context * ctx,
+            struct llama_batch   batch);
+
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a6d5ddfa33..d9a00ee748 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1456,9 +1456,21 @@ static void copy_tensor_async_candidates(
     }
 }
 
-int llama_context::decode(const llama_batch & batch_inp) {
+int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
+    if (is_mtp) {
+        if (model.hparams.nextn_predict_layers == 0) {
+            LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
+            return -1;
+        }
+        if (batch_inp.n_tokens > n_ubatch()) {
+            // TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
+            LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
+            return -1;
+        }
+    }
+
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -1587,6 +1599,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
         break;
     }
 
+    const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0;
+
+    // set MTP state if needed
+    if (update_mtp_state) {
+        cross.n_embd = hparams.get_n_embd_mtp();
+        cross.n_token = n_outputs;
+        cross.mtp_embd.resize(cross.n_embd*cross.n_token);
+    }
+
     // reserve output buffer
     if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
@@ -1615,7 +1636,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         ggml_status status;
-        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        llm_graph_type gtype = is_mtp ? LLM_GRAPH_TYPE_DECODER_MTP : LLM_GRAPH_TYPE_DECODER;
+        const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
@@ -1683,6 +1705,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
+            // set MTP state if needed
+            if (update_mtp_state) {
+                const int64_t n_embd_mtp = cross.n_embd;
+                GGML_ASSERT( n_outputs_prev + n_outputs             <= n_outputs_all);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_mtp <= (int64_t)cross.mtp_embd.size());
+                ggml_backend_tensor_get_async(backend_embd, t_embd, cross.mtp_embd.data(), 0, n_outputs*n_embd_mtp*sizeof(float));
+            }
+
             switch (cparams.pooling_type) {
                 case LLAMA_POOLING_TYPE_NONE:
                     {
@@ -3523,7 +3553,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    const int ret = ctx->decode(batch, false);
     if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -3531,6 +3561,17 @@
     return ret;
 }
 
+int32_t llama_decode_mtp(
+        llama_context * ctx,
+          llama_batch   batch) {
+    const int ret = ctx->decode(batch, true);
+    if (ret != 0 && ret != 1) {
+        LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
 //
 // perf
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index 86decc05fb..aaf58a533c 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -131,7 +131,7 @@ struct llama_context {
             ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
-    int decode(const llama_batch & batch_inp);
+    int decode(const llama_batch & batch_inp, bool is_mtp);
 
     //
     // state save/load
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 901027c0ef..a6dd1525d1 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1431,7 +1431,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
-    auto inp = std::make_unique(hparams.n_pos_per_embd());
+    auto inp = std::make_unique(cross);
 
     auto & cur = inp->cross_mtp;
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 55c83830b3..adbd335137 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -30,6 +30,7 @@ enum llm_graph_type {
     LLM_GRAPH_TYPE_DEFAULT,
    LLM_GRAPH_TYPE_ENCODER,
     LLM_GRAPH_TYPE_DECODER,
+    LLM_GRAPH_TYPE_DECODER_MTP,
 };
 
 enum llm_ffn_op_type {
@@ -444,7 +445,6 @@ class llm_graph_result;
 
 struct llm_graph_params {
     llm_arch arch = LLM_ARCH_UNKNOWN;
-    bool is_mtp = false;
 
     llama_hparams hparams;
     llama_cparams cparams;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index c847ef91b7..29dbd08f32 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -76,6 +76,10 @@ uint32_t llama_hparams::get_n_embd_out() const {
     return n_embd_out > 0 ? n_embd_out : n_embd;
 }
 
+uint32_t llama_hparams::get_n_embd_mtp() const {
+    return n_embd;
+}
+
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     const uint32_t n_head_kv = this->n_head_kv(il);
 
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 7ae3ec292e..588ce2987b 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -240,6 +240,9 @@ struct llama_hparams {
     // dimension of output embeddings
     uint32_t get_n_embd_out() const;
 
+    // dimension of cross embeddings between main LLM and MTP
+    uint32_t get_n_embd_mtp() const;
+
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8da41ebeeb..f1194afe4b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7871,7 +7871,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                if (params.is_mtp) {
+                if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
                     llm = std::make_unique>(*this, params);
                 } else {
                     llm = std::make_unique>(*this, params);
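
Not part of the patch: a minimal usage sketch of the new entry point, assuming the existing llama.cpp C API (llama_batch_get_one, llama_sampler_sample, llama_get_logits_ith) and a hypothetical helper named draft_one_token_mtp. The exact position/sequence bookkeeping expected by llama_decode_mtp() may differ in the final implementation; this only illustrates the call order described in the llama.h comment (MTP input tokens follow the output order of the previous llama_decode() or llama_decode_mtp() call).

#include "llama.h"

// Draft one token with the MTP head right after a regular decode step.
// Returns false if either decode call fails (e.g. the model has no MTP
// layers, in which case llama_decode_mtp() returns -1).
static bool draft_one_token_mtp(llama_context * ctx, llama_sampler * smpl, llama_token last_token) {
    // 1) regular decode of the last accepted token
    llama_batch batch = llama_batch_get_one(&last_token, 1);
    if (llama_decode(ctx, batch) != 0) {
        return false;
    }

    // 2) sample the next token from the main model head
    llama_token next = llama_sampler_sample(smpl, ctx, -1);

    // 3) feed the sampled token to the MTP head, matching the output order
    //    of the llama_decode() call above
    llama_batch batch_mtp = llama_batch_get_one(&next, 1);
    if (llama_decode_mtp(ctx, batch_mtp) != 0) {
        return false;
    }

    // 4) the MTP logits are read like regular logits and can seed a draft token
    const float * draft_logits = llama_get_logits_ith(ctx, 0);
    (void) draft_logits;

    return true;
}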