From de17303dc63ad478edf7314961aa3777a8f2cf14 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 17 Jan 2026 20:54:09 +0100
Subject: [PATCH] llama_mtp_start

---
 include/llama.h         | 13 +++++--------
 src/llama-context.cpp   | 43 +++++++++++++++++++++++++++--------------
 src/llama-context.h     |  6 +++++-
 src/models/glm4-moe.cpp |  4 +++-
 4 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 71537b3d53..7e9518bdae 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -370,6 +370,7 @@ extern "C" {
         bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        bool is_mtp;      // create context for Multi-Token Prediction (MTP)
 
         // [EXPERIMENTAL]
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
@@ -939,14 +940,6 @@ extern "C" {
             struct llama_context * ctx,
             struct llama_batch batch);
 
-    // Process a batch of tokens using MTP (Multi-Token Prediction).
-    // The input token can be either from the last llama_decode() call,
-    // or from the previous llama_decode_mtp() call.
-    // Input token order must match the output token order from the previous call.
-    LLAMA_API int32_t llama_decode_mtp(
-            struct llama_context * ctx,
-            struct llama_batch batch);
-
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
@@ -1014,6 +1007,10 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Copy the internal MTP state from ctx_llm to ctx_mtp, ready for MTP decoding.
+    // This must be done before calling llama_decode() on ctx_mtp
+    LLAMA_API int32_t llama_mtp_start(struct llama_context * ctx_llm, struct llama_context * ctx_mtp);
+
     //
     // backend sampling API [EXPERIMENTAL]
     // note: use only if the llama_context was created with at least one llama_sampler_seq_config
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d9a00ee748..a7bb2882ae 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -27,6 +27,8 @@ llama_context::llama_context(
     //       may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
+    is_mtp = params.is_mtp;
+
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
 
@@ -814,6 +816,23 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
+    if (!ctx_mtp.is_mtp) {
+        LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
+        return -1;
+    }
+
+    if (cross.n_token == 0 || cross.n_embd == 0) {
+        LLAMA_LOG_ERROR("%s: no state to copy\n", __func__);
+        return -1;
+    }
+
+    // TODO: maybe std::move is better?
+    ctx_mtp.cross = cross;
+
+    return 0;
+}
+
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
@@ -1456,7 +1475,7 @@ static void copy_tensor_async_candidates(
     }
 }
 
-int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
+int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
     if (is_mtp) {
@@ -1464,7 +1483,7 @@ int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
             LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
             return -1;
         }
-        if (batch_inp.n_tokens > n_ubatch()) {
+        if ((uint32_t)batch_inp.n_tokens > n_ubatch()) {
             // TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
             LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
             return -1;
@@ -3043,6 +3062,7 @@ llama_context_params llama_context_default_params() {
         /*.op_offload                =*/ true,
         /*.swa_full                  =*/ true,
         /*.kv_unified                =*/ false,
+        /*.is_mtp                    =*/ false,
         /*.sampler                   =*/ nullptr,
         /*.n_sampler                 =*/ 0,
     };
@@ -3233,6 +3253,12 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
+    ctx_llm->synchronize();
+
+    return ctx_llm->cpy_mtp_state(*ctx_mtp);
+}
+
 bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
     return ctx->set_sampler(seq_id, smpl);
 }
@@ -3553,7 +3579,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    const int ret = ctx->decode(batch, false);
+    const int ret = ctx->decode(batch);
     if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -3561,17 +3587,6 @@ int32_t llama_decode(
     return ret;
 }
 
-int32_t llama_decode_mtp(
-        llama_context * ctx,
-        llama_batch batch) {
-    const int ret = ctx->decode(batch, true);
-    if (ret != 0 && ret != 1) {
-        LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
 //
 // perf
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index aaf58a533c..265c339d62 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -77,6 +77,8 @@ struct llama_context {
     float * get_embeddings();
     float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);
+
+    int32_t cpy_mtp_state(llama_context & ctx_mtp);
 
     llama_token * get_sampled_tokens() const;
     llama_token   get_sampled_token_ith(int32_t idx);
@@ -131,7 +133,7 @@ struct llama_context {
             ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
-    int decode(const llama_batch & batch_inp, bool is_mtp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
@@ -349,6 +351,8 @@ private:
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
+    bool is_mtp = false;
+
     bool has_evaluated_once = false;
 
     // env: LLAMA_GRAPH_REUSE_DISABLE
diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp
index 6e81167671..8a87c63faf 100644
--- a/src/models/glm4-moe.cpp
+++ b/src/models/glm4-moe.cpp
@@ -54,7 +54,6 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const l
 template <>
 llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
-    const bool use_mrope = hparams.use_mrope();
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -73,6 +72,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const ll
                 ? mtp_layer.nextn.embed_tokens
                 : model.tok_embd);
     ggml_tensor * inp_state_embd = build_inp_cross_mtp();
 
+    // check number of input tokens
+    GGML_ASSERT(inp_state_embd->ne[1] == inp_token_embd->ne[1]);
+
     inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
     inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
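
Usage sketch (not part of the patch, illustration only). Everything below except is_mtp and llama_mtp_start() is the existing llama.cpp C API; the greedy argmax stands in for a real sampler, error handling is omitted, and the exact MTP input convention is inferred from the header comment and the removed llama_decode_mtp() docs: decode on the main context, call llama_mtp_start() to copy the captured state into the MTP context, then run a regular llama_decode() there.

#include "llama.h"

// One MTP draft step on a second context created with is_mtp = true.
static void mtp_draft_step(llama_model * model, llama_token last_tok) {
    // main (LLM) context
    llama_context_params p_llm = llama_context_default_params();
    llama_context * ctx_llm = llama_init_from_model(model, p_llm);

    // MTP context: same weights, but the graph is built through the MTP head
    llama_context_params p_mtp = llama_context_default_params();
    p_mtp.is_mtp = true;
    llama_context * ctx_mtp = llama_init_from_model(model, p_mtp);

    // 1) regular decode on the main context
    llama_batch batch_llm = llama_batch_get_one(&last_tok, 1);
    llama_decode(ctx_llm, batch_llm);

    // pick the next token greedily from the main context's logits (placeholder for a real sampler)
    const float * logits  = llama_get_logits_ith(ctx_llm, -1);
    const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
    llama_token next_tok  = 0;
    for (int32_t i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[next_tok]) {
            next_tok = i;
        }
    }

    // 2) copy the captured MTP state into the MTP context (required before decoding on it)
    if (llama_mtp_start(ctx_llm, ctx_mtp) == 0) {
        // 3) decode the sampled token on the MTP context to obtain draft logits;
        //    the token count must match the state captured by the previous decode
        llama_batch batch_mtp = llama_batch_get_one(&next_tok, 1);
        llama_decode(ctx_mtp, batch_mtp);
    }

    llama_free(ctx_mtp);
    llama_free(ctx_llm);
}

Note that the MTP context rejects batches larger than n_ubatch (see the new check in llama_context::decode), and the GGML_ASSERT added in glm4-moe.cpp requires the MTP batch size to match the number of state rows copied by llama_mtp_start().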