From fe2baf5e2d685e05e88176fe4e6550787b9c2fbd Mon Sep 17 00:00:00 2001 From: samuel Date: Tue, 9 Dec 2025 22:50:04 -0300 Subject: [PATCH 1/6] Squashed commit of the following: commit 912ed2cd9339d1b2875d98744ca5b51fa62e581e Author: samuel Date: Sun Dec 7 23:00:29 2025 -0300 speculative (feat): implement recursive MTP drafting for GLM-4.5 commit bdf72d9552e3da64ffc85f175664713388752914 Author: samuel Date: Sat Dec 6 16:10:16 2025 -0300 sampling (feat): optimize speculative drafting with fast-path selection commit a91980a8f3475a6bbac0a64d8be06dd4b613020e Author: samuel Date: Sat Dec 6 15:18:19 2025 -0300 mtp (chore): clean old code commit 6de0ecf55db8567db4faa99b0152b72c9e854548 Author: samuel Date: Sat Dec 6 14:40:13 2025 -0300 mtp (feat): add mtp arg commit ea77394183b8e6c368af969b8274039a54b11486 Author: samuel Date: Sat Dec 6 13:47:54 2025 -0300 mtp-graph (fix): move llama_get_logits_ith outside the loop commit 15dff208958fb66802f20ec53ce5fcaff133edb7 Merge: 171346c74 cae85fe53 Author: samuel Date: Thu Oct 16 13:44:41 2025 -0300 Merge branch 'glm4-mtp-batch' of https://github.com/SamuelOliveirads/llama.cpp into glm4-mtp-graph-cache commit cae85fe531876762ee02524fc4c3f6c5e7824c63 Author: samuel Date: Thu Oct 16 13:42:31 2025 -0300 mtp-batch(fix): avoid logits for mtp kv cache operations commit 171346c742c310bbcfbd786b61250638ccf8b44d Author: samuel Date: Sun Oct 12 16:33:01 2025 -0300 mtp-graph(feat): Reactivate graph reuse only for main model path commit 0127c6beeb384ec3abbc18b22dbe830f22fcf4b4 Author: samuel Date: Sat Oct 11 22:20:54 2025 -0300 mtp-batch(chore): Remove final MTP debug logs and dead code commit 4bcc9e261ef57ee4cfaa65d06bcd0fcdeacf7797 Author: samuel Date: Sat Oct 11 18:51:22 2025 -0300 mtp-batch(fix): Correctly advance cache head and add MTP documentation commit b4cbe030ac25056717763b812d1dd89681c08522 Author: samuel Date: Sat Oct 11 18:37:40 2025 -0300 mtp-batch(chore): Fix logit flags for speculative sampling and remove debug logs commit a99709d0c1401d0b447dce1bd0101fb56390f50e Author: samuel Date: Fri Oct 10 17:24:34 2025 -0300 mtp-batch(refactor): Extract decode context and MTP input logic into helper methods commit 913af8f48d2dab1d9e907cf6c48c921a229a295c Author: samuel Date: Fri Oct 10 16:44:28 2025 -0300 mtp-batch(refactor): Replace MTP boolean flags with an explicit operation enum commit 6f74ba38070d62d37bc0fb71ce9871e1a4ffabcc Author: samuel Date: Thu Oct 9 22:27:18 2025 -0300 mtp-batch (fix): prevent mtp draft from polluting the cache commit 5e1d719beffccf8c22784c24b52ff6f5ab56b9ff Author: samuel Date: Thu Oct 9 15:21:23 2025 -0300 mtp-batch (feat): Create and manage sinfo for MTP commit febd8235d27fe9174ee4b54ea7a10e630939fee0 Author: samuel Date: Sun Oct 5 14:43:40 2025 -0300 mtp-batch (wip): fix how to warmup kv cache for MTP commit 67c6c069e0a5496adfd7d8aa6ca7514db5a6f437 Author: samuel Date: Sat Sep 27 19:42:32 2025 -0300 mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption commit 75dc25e6fe781c1b65038d69390fb778d760e3a1 Author: samuel Date: Sat Sep 27 17:17:00 2025 -0300 mtp-batch (wip): organize batch for mtp cache commit 3da7e7f3309dbb576538850c92c1cbf8fdc6d6ee Author: samuel Date: Tue Sep 23 22:45:11 2025 -0300 mtp-batch (fix): warm mtp cache for small batch size commit df64508b937784112168aa099644b60fef015f05 Author: samuel Date: Sun Sep 21 21:55:41 2025 -0300 mtp-batch (wip): merge glm graphs commit 042eb8a829876ed175320df9c8133bcea0c40460 Author: samuel Date: Sun Sep 21 21:29:00 2025 -0300 mtp-batch (wip): merge mtp and 
model graph commit 1318b2de82716710b9853e07bd640443a5a025bb Author: samuel Date: Sun Sep 14 10:22:59 2025 -0300 mtp-batch (wip): move mtp execution to batch format commit c6237c71ffd4485df1c35829c380b63e472fc5dd Merge: 9fab53e43 8742ce0e3 Author: Aaron Lee Date: Sat Sep 13 02:57:01 2025 -0400 Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp feat: implemented sampling for MTP commit 8742ce0e39823eeb101bb5b6099ff4ca7be10c6e Author: samuel Date: Sat Sep 6 00:21:18 2025 -0300 feat: apply logits + greedy sampler commit 5a5bce85777041d841393b4396e28f8e3065bb10 Author: samuel Date: Wed Sep 3 17:56:14 2025 -0300 fix: add sample acceptance commit 07670a22c63b1fa335d6ec1c4a1e4255a920848c Author: samuel Date: Wed Sep 3 13:25:21 2025 -0300 feat: implemented sampling for MTP commit 9fab53e4388c20aef497efd82e86dcb99ca58064 Author: Aaron Lee Date: Tue Sep 2 17:14:09 2025 -0400 fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch commit 98bc0c6bf223f425f4ecea14f13fc46101f1b44a Author: Aaron Lee Date: Tue Aug 26 01:26:51 2025 -0400 replace standard sampler with greedy sampler for mtp draft commit 471e026327cca9f6f58aeefe32129a6cb9390f4f Author: Aaron Lee Date: Tue Aug 19 23:10:56 2025 -0400 fixed vram leak commit d72f9d5691054958cd1b139f228e5e588d3974cf Author: Aaron Lee Date: Tue Aug 19 01:50:34 2025 -0400 kludge-y kv cache management of mtp layer commit 382135aa3619294ab8bf87b0de4b1255ab7942f0 Author: Aaron Lee Date: Sun Aug 17 21:54:45 2025 -0400 fixed mtp kv cache update sequencing after prompt processing commit 6870f9790c1bb1d0254241267b1a6c8a7fc82830 Author: Aaron Lee Date: Sun Aug 17 04:59:36 2025 -0400 added proper KV cache management for MTP layers and slightly refactored commit 6e9bafc7a738b4c99f9440c0ec461e08cf6ce702 Author: Aaron Lee Date: Fri Aug 15 23:13:56 2025 -0400 failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable commit cf0f7c0448c2c1736588673114558e5829db7879 Author: Aaron Lee Date: Wed Aug 13 02:21:17 2025 -0400 broad thrust of the mtp implementation commit 03231da69eec20677e25e2307d4fe31ac2ede034 Author: Aaron Lee Date: Tue Aug 12 01:03:59 2025 -0400 add model member function to build mtp graph, to be called from speculative.cpp commit 1f477b375504aa557ed21066aa6783b11781a179 Author: Aaron Lee Date: Mon Aug 11 20:54:45 2025 -0400 make nextn weights loadable without a crash commit e434f87cc739a1901931d88e33f777170a4e18e7 Author: Aaron Lee Date: Mon Aug 11 01:21:47 2025 -0400 some work towards building mtp layer graph commit db60623e7926fb151b3cc63f029929122cac342a Author: Aaron Lee Date: Sun Aug 10 23:52:54 2025 -0400 added getter for nextn layer count and server slot has_mtp property --- common/arg.cpp | 7 + common/common.h | 1 + common/sampling.cpp | 39 +++ common/speculative.cpp | 113 +++++++++ common/speculative.h | 47 +++- include/llama.h | 46 ++++ src/llama-arch.cpp | 13 +- src/llama-batch.cpp | 35 +-- src/llama-context.cpp | 309 +++++++++++++++++----- src/llama-context.h | 26 +- src/llama-graph.cpp | 20 ++ src/llama-graph.h | 18 ++ src/llama-kv-cache.cpp | 131 +++++++--- src/llama-kv-cache.h | 17 +- src/llama-model.cpp | 15 +- src/models/glm4-moe.cpp | 437 +++++++++++++++++++++----------- src/models/models.h | 2 + tools/server/server-context.cpp | 73 +++++- 18 files changed, 1053 insertions(+), 296 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 1302065498..e919f618dd 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3214,6 +3214,13 @@ common_params_context 
common_params_parser_init(common_params & params, llama_ex params.speculative.cache_type_k = kv_cache_type_from_str(value); } ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT")); + add_opt(common_arg( + {"-mtp", "--multi-token-prediction"}, + string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"), + [](common_params & params) { + params.mtp = true; + } + )); add_opt(common_arg( {"-ctvd", "--cache-type-v-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 334372073a..6c2f1dc686 100644 --- a/common/common.h +++ b/common/common.h @@ -430,6 +430,7 @@ struct common_params { bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) bool no_host = false; // bypass host buffer allowing extra buffers to be used + bool mtp = false; // use mtp is supported bool single_turn = false; // single turn chat conversation diff --git a/common/sampling.cpp b/common/sampling.cpp index c66f935c65..3254f8d66c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -666,3 +666,42 @@ std::vector common_sampler_types_from_chars(const std::stri return samplers; } + +/** + * Specialized sampling for speculative drafting. + * + * Prioritizes performance by using a direct ArgMax loop (Greedy) when no + * penalties (repetition, frequency, presence, DRY) are configured. + * Falls back to the full sampler chain if penalties are active to prevent + * generative loops or adhere to constraints. + */ +llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { + const auto & params = gsmpl->params; + + bool use_heavy_sampler = + (params.penalty_last_n > 0 && ( + params.penalty_repeat != 1.0f || + params.penalty_freq != 0.0f || + params.penalty_present != 0.0f + )) || + (params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f); + + if (use_heavy_sampler) { + return common_sampler_sample(gsmpl, ctx, idx, false); + } + + float * logits = llama_get_logits_ith(ctx, idx); + const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx))); + + int best_id = 0; + float max_val = logits[0]; + + for (int i = 1; i < n_vocab; ++i) { + if (logits[i] > max_val) { + max_val = logits[i]; + best_id = i; + } + } + + return best_id; +} \ No newline at end of file diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e83b0964c..136f2c1b1a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -359,3 +359,116 @@ llama_tokens common_speculative_gen_draft( } return result; } + +llama_tokens mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + struct common_speculative_params params, + llama_token id_last, + int32_t n_past, + llama_seq_id seq_id) { + + int n_draft = params.n_draft; + + llama_tokens drafts; + drafts.reserve(n_draft); + + if (!smpl) return drafts; + + llama_batch mtp_batch = llama_batch_init(1, 0, 1); + mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; + + llama_token current_input_id = id_last; + int32_t current_n_past = n_past; + + for (int i = 0; i < n_draft; ++i) { + mtp_batch.n_tokens = 0; + common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true); + + // Perform the MTP draft generation decode. This writes the MTP layer's + // KV state for the draft token into the cache. 
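// --- editorial aside (not part of the patch) ---
// The hand-rolled ArgMax loop in common_sampler_sample_speculative() above is the greedy
// fast path taken when no repetition/DRY penalties are configured. For reviewers, it is
// equivalent to this <algorithm>-based helper (the helper name is illustrative only and
// assumes <algorithm> and <iterator> are included):
static llama_token spec_argmax(const float * logits, int n_vocab) {
    // first maximum wins, matching the explicit loop above
    return (llama_token) std::distance(logits, std::max_element(logits, logits + n_vocab));
}
// --- end editorial aside ---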
+ if (llama_decode(ctx, mtp_batch) != 0) { + break; + } + + llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0); + + // Drafting stops if token probability drops below `p_min` to save compute. + const auto * cur_p = common_sampler_get_candidates(smpl, true); + if (cur_p && cur_p->size > 0) { + float prob = cur_p->data[0].p; + + if (prob < params.p_min) { + drafts.push_back(id_next); + current_n_past++; + break; + } + } + + drafts.push_back(id_next); + + current_input_id = id_next; + current_n_past++; + } + llama_batch_free(mtp_batch); + + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. + if (!drafts.empty()) { + llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past); + } + + return drafts; +} + + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { + if (batch.n_tokens == 0) { + return; + } + + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + + llama_batch mtp_batch = batch; + if (is_prompt_warmup) { + mtp_batch.mtp_params.op_type = MTP_OP_WARMUP; + } else { + mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED; + } + + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = true; + } + llama_decode(ctx, mtp_batch); +} + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +) { + if (ids.empty()) { + return; + } + + // Prepare a resized copy of the validation sinfo to match the number of accepted tokens. + // This sets up the context for a "forced sinfo" decode. + if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) { + return; + } + + // Build a new batch containing only the accepted tokens. + llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); + } + + mtp_update_kv_cache(ctx, accepted_batch, false); + + // Clean up the forced state to not affect subsequent, normal decode calls. + llama_mtp_cancel_sinfo_update(ctx); + + llama_batch_free(accepted_batch); +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1e..a33c5a8b02 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,6 +12,12 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; +struct mtp_kv_update_data { + llama_token id; + int32_t n_past; + int32_t tok_idx; +}; + struct common_speculative * common_speculative_init( struct llama_context * ctx_tgt, struct llama_context * ctx_dft @@ -29,7 +35,40 @@ void common_speculative_add_replacement_tgt_dft( // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); + +/** + * @brief Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture. + * + * This function performs a recursive generation loop using the MTP head (e.g., Eagle/NextN). 
+ * It uses the fixed hidden state from the main model's last step and updates the MTP layer's + * internal KV cache autoregressively. + * + * @param smpl The sampler instance. + * @param ctx The llama context (shared between Main and MTP). + * @param params Speculative parameters (n_draft, p_min). + * @param id_last The last confirmed token ID from the main model. + * @param n_past The number of tokens in the validated past (start position for drafting). + * @param seq_id The sequence ID to use for drafting. + * + * @return std::vector The generated draft tokens. + */ +llama_tokens mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + struct common_speculative_params params, + llama_token id_last, + int32_t n_past, + llama_seq_id seq_id); + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +); diff --git a/include/llama.h b/include/llama.h index f862930099..ce44c2d345 100644 --- a/include/llama.h +++ b/include/llama.h @@ -228,6 +228,17 @@ extern "C" { // - if not: only the last token is output // ) // + typedef enum { + MTP_OP_NONE, + MTP_OP_WARMUP, + MTP_OP_UPDATE_ACCEPTED, + MTP_OP_DRAFT_GEN, + } llama_mtp_op_type; + + typedef struct llama_mtp_params { + llama_mtp_op_type op_type; + } llama_mtp_params; + typedef struct llama_batch { int32_t n_tokens; @@ -237,6 +248,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + llama_mtp_params mtp_params; } llama_batch; enum llama_model_kv_override_type { @@ -536,6 +548,8 @@ extern "C" { LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); + LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model); + // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure // - The output string is always null-terminated and cleared on failure @@ -1442,6 +1456,38 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + // + // MTP + // + + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + + /** + * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo. + * This is used after speculative validation when only a subset of draft tokens are accepted. + * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); + + /** + * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode. + * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + + /** + * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo. + */ + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + + /** + * @brief Removes KV cache metadata for a specified sequence and token range. + * This makes the physical cells logically available again without deleting the tensor data. 
+ */ + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + #ifdef __cplusplus } #endif diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d0eaf317f7..2c949bc0c0 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2370,12 +2370,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 386fab04ac..fc0e035b79 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -301,17 +301,17 @@ bool llama_batch_allocr::init( ok = false; } - if (!ok) { - LLAMA_LOG_ERROR( - "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" - " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" - " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" - " it is required that the sequence positions remain consecutive: Y = X + 1\n", - __func__, s, s, p0, s, seq_pos_min(s)); + // if (!ok) { + // LLAMA_LOG_ERROR( + // "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" + // " - the last position stored in the memory module of the context (i.e. 
the KV cache) for sequence %d is X = %d\n" + // " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n" + // " it is required that the sequence positions remain consecutive: Y = X + 1\n", + // __func__, s, s, p0, s, seq_pos_min(s)); - return false; - } - } + // return false; + // } + } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s); @@ -874,13 +874,14 @@ struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*.mtp_params =*/ { MTP_OP_NONE }, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8786d4ee3e..ba2d2a32c2 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -7,6 +7,7 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-kv-cache.h" #include #include @@ -17,6 +18,13 @@ // // llama_context // +// Key for the graph cache. It contains all parameters that define the graph topology. + +struct llama_context_kv_cache_data { + llama_kv_cache::slot_info_vec_t last_main_model_sinfos; + llama_kv_cache::slot_info_vec_t resized_sinfo_for_force; + const llama_kv_cache::slot_info_vec_t * forced_sinfos = nullptr; +}; llama_context::llama_context( const llama_model & model, @@ -136,6 +144,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + kv_cache_data = new llama_context_kv_cache_data(); + + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? 
(atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -477,6 +488,7 @@ llama_context::~llama_context() { // } // } ggml_opt_free(opt_ctx); + delete static_cast(kv_cache_data); } void llama_context::synchronize() { @@ -712,6 +724,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +ggml_tensor * llama_context::get_embeddings_tensor() { + return embd_tensor; +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -806,7 +822,8 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, + const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -818,7 +835,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -849,6 +866,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation + if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { + ret = GGML_STATUS_FAILED; + return nullptr; + } + } + // set the input data for the input tensors { //const auto t_start_us = ggml_time_us(); @@ -927,7 +951,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE }); cparams.causal_attn = causal_attn_org; @@ -1035,6 +1059,8 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + auto * kvd = static_cast(kv_cache_data); + if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(batch_inp); @@ -1089,10 +1115,11 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending shifts/copies memory_update(false); - llama_memory_context_ptr mctx; + std::unique_ptr mctx; while (true) { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + mctx = this->initialize_decode_context(batch_inp, output_all); + if (!mctx) { return -2; } @@ -1109,6 +1136,12 @@ int llama_context::decode(const llama_batch & batch_inp) { } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { + if (kvd->forced_sinfos) { + LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + + return -1; + } + if (!did_optimize) { did_optimize = true; @@ -1162,7 +1195,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } 
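// --- editorial sketch (not part of the patch) ---
// How callers tag a batch so the decode path below routes it through the MTP graph:
// MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED are cache-only passes whose logits decode()
// deliberately discards, while MTP_OP_DRAFT_GEN produces the logits read by the draft
// sampler. The helper name and the use of common_batch_add() (from common/common.h) are
// illustrative only; the real callers live in common/speculative.cpp.
static int mtp_cache_only_decode(llama_context * ctx, llama_token id, llama_pos pos, llama_seq_id seq_id) {
    llama_batch b = llama_batch_init(/*n_tokens_alloc=*/1, /*embd=*/0, /*n_seq_max=*/1);
    common_batch_add(b, id, pos, { seq_id }, /*logits=*/true); // output flag set, but decode() skips the copy
    b.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
    const int ret = llama_decode(ctx, b);
    llama_batch_free(b);
    return ret;
}
// --- end editorial sketch ---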
ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1209,71 +1242,81 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + // MTP operations that are purely for updating the KV cache + // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // as a side effect of running the graph. If these logits are copied + // back to the main context buffer, they will overwrite the valid logits + // produced by the main model's pass, leading to incorrect sampling. + // This condition explicitly prevents that copy for cache-only operations. + if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && + batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits + n_outputs_prev*n_vocab; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } } // extract embeddings if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = 
ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1478,7 +1521,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -1505,8 +1548,9 @@ ggml_cgraph * llama_context::graph_reserve( llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, - const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + const llama_memory_context_i * mctx, + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1519,12 +1563,28 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.mtp_params =*/ mtp_params, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, }; } +std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { + const auto& vocab = model.vocab; + const auto& hparams = model.hparams; + + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, false)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return nullptr; + } + + return memory->init_batch(*balloc, 1, false); +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -2266,7 +2326,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -3055,3 +3115,122 @@ void llama_opt_epoch( callback_train, callback_eval); } + +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} + +bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty()) { + LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__); + return false; + } + + kvd->forced_sinfos = &last_sinfo; + return true; +} + + +bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) { + LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__); + return false; + } + + kvd->resized_sinfo_for_force = last_sinfo; + + kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted); + + kvd->forced_sinfos = &kvd->resized_sinfo_for_force; + + return true; +} + +void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + kvd->forced_sinfos = nullptr; +} + +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); + } +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} + +/* + Initializes the memory context for a decode operation. + The logic follows a specific priority: + 1. Warmup: Always use a standard batch initialization. + 2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it. + 3. Default: Use a standard batch initialization, and if it's a main model pass, + save the resulting s-info for potential future reuse by MTP. 
+*/ +std::unique_ptr llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) { + auto * kvd = static_cast(kv_cache_data); + std::unique_ptr mctx; + + if (cparams.warmup) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { + LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__); + mctx = static_cast(memory.get())->init_batch_with_sinfos( + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + ); + } else { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { + kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); + } else { + kvd->last_main_model_sinfos.clear(); + } + } + } + + return mctx; +} + + +bool llama_context::prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params) { + + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); + + const float * source_hidden_state = nullptr; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + source_hidden_state = this->embd; + } else { // MTP_OP_DRAFT_GEN + source_hidden_state = this->draft_input_hidden_state; + } + + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + const char * op_type; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + op_type = "MTP_UPDATE"; + } else { // MTP_OP_DRAFT_GEN + op_type = "DRAFT_GEN"; + } + + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); + } else { + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + return false; + } + + return true; +} diff --git a/src/llama-context.h b/src/llama-context.h index c31101330e..3bf3483b08 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -32,6 +32,8 @@ struct llama_memory_breakdown_data { } }; +struct llama_context_kv_cache_data; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -69,6 +71,11 @@ struct llama_context { float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + ggml_tensor * get_embeddings_tensor(); + + const float * draft_input_hidden_state = nullptr; + + void * kv_cache_data = nullptr; void attach_threadpool( ggml_threadpool_t threadpool, @@ -100,6 +107,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -108,7 +117,8 @@ struct llama_context { const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, - ggml_status & ret); + ggml_status & ret, + const llama_mtp_params & mtp_params); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -218,10 +228,21 @@ private: llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + const 
llama_mtp_params & mtp_params) const; llm_graph_cb graph_get_cb() const; + // Methods for MTP decode + std::unique_ptr initialize_decode_context(const llama_batch & batch_inp, const bool output_all); + + bool prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params); + + std::unique_ptr mtp_memory_batch(const llama_batch & batch_inp); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -251,6 +272,7 @@ private: // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + ggml_tensor * embd_tensor = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d0d7197e1..75da81e2ff 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1254,6 +1254,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { return cur; } + +ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const { + auto inp = std::make_unique(); + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_name(inp->tokens, "mtp_inp_tokens"); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens); + } else { + GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings"); + } + + cb(cur, "mtp_inp_embd", -1); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos() const { auto inp = std::make_unique(hparams.n_pos_per_embd()); diff --git a/src/llama-graph.h b/src/llama-graph.h index 81ac329cc3..0a0de82c48 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -102,6 +103,20 @@ protected: using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -428,6 +443,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_mtp_params mtp_params; uint32_t n_outputs; @@ -476,6 +492,7 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + mtp_params.op_type == other.mtp_params.op_type && n_outputs == other.n_outputs; } }; @@ -690,6 +707,7 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const; ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3186242d60..8f192482bf 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -542,6 +542,34 @@ llama_memory_context_ptr 
llama_kv_cache::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } +llama_memory_context_ptr llama_kv_cache::init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update) { + + if (sinfos.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + balloc.split_reset(); + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); + } + + if (ubatches.size() != sinfos.size()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, sinfos, std::move(ubatches), is_inplace_update); +} + llama_memory_context_ptr llama_kv_cache::init_full() { return std::make_unique(this); } @@ -888,40 +916,61 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } assert(res.s1 >= res.s0); + if (!res.empty()) { + std::string idxs_str; + for (const auto& vec : res.idxs) { + if (!vec.empty()) { + if (vec.size() > 8) { + idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; + } else { + idxs_str += " ["; + for(size_t i = 0; i < vec.size(); ++i) { + idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", "); + } + idxs_str += "]"; + } + } + } + } return res; } -void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { - // keep track of the max sequence position that we would overwrite with this ubatch - // for non-SWA cache, this would be always empty - llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - seq_pos_max_rm[s] = -1; - } +void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + // For "in-place" updates (MTP warmup/accept), we only update the tensor data. + // The cell metadata (logical position, sequence ID) has already been set + // by the main model's pass. We must skip all metadata modifications + // to prevent `pos_set` from asserting on an already-set cell. 
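// --- editorial sketch (not part of the patch) ---
// The call sequence that reaches the is_inplace_update branch below, mirroring
// mtp_accept_tokens() in common/speculative.cpp: a resized copy of the last main-model
// sinfo is forced, the accepted tokens are decoded in place over the existing cells, and
// the forced state is cleared again. The helper name and parameters are illustrative only.
static bool mtp_replay_accepted(llama_context * ctx, llama_batch & accepted_batch, size_t n_accepted) {
    if (!llama_mtp_prepare_sinfo_for_update(ctx, n_accepted)) {
        return false;                                       // no sinfo saved from the last main-model pass
    }
    accepted_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
    const bool ok = llama_decode(ctx, accepted_batch) == 0; // rewrites K/V data in the reused cells
    llama_mtp_cancel_sinfo_update(ctx);                     // always clear the forced sinfo afterwards
    return ok;
}
// --- end editorial sketch ---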
+ if (!is_inplace_update) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } - assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { - const uint32_t i = s*sinfo.size() + ii; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - auto & cells = v_cells[sinfo.strm[s]]; + auto & cells = v_cells[sinfo.strm[s]]; - const auto idx = sinfo.idxs[s][ii]; + const auto idx = sinfo.idxs[s][ii]; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - cells.rm(idx); - } + cells.rm(idx); + } - cells.pos_set(idx, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); if (ubatch.is_pos_2d()) { llama_kv_cell_ext ext { @@ -931,29 +980,30 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & cells.ext_set(idx, ext); } - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } - } - // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence - // will be present in the cache. so we have to purge any position which is less than those we would overwrite - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (seq_pos_max_rm[s] == -1) { - continue; - } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } - GGML_ASSERT(s < seq_to_stream.size()); + GGML_ASSERT(s < seq_to_stream.size()); - auto & cells = v_cells[seq_to_stream[s]]; + auto & cells = v_cells[seq_to_stream[s]]; - if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { + LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); - seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } } } @@ -2010,7 +2060,8 @@ llama_kv_cache_context::llama_kv_cache_context( llama_kv_cache_context::llama_kv_cache_context( llama_kv_cache * kv, llama_kv_cache::slot_info_vec_t sinfos, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { + std::vector ubatches, + bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { } llama_kv_cache_context::~llama_kv_cache_context() = default; @@ -2035,7 +2086,7 @@ bool llama_kv_cache_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); n_kv = kv->get_n_kv(sinfos[i_cur]); return true; @@ -2098,3 +2149,7 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_sinfos(llama_kv_cache_context::slot_info_vec_t new_sinfos) { + sinfos = new_sinfos; +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 1868f11857..065a0cc655 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -118,6 +118,12 @@ public: llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; + + llama_memory_context_ptr init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update); llama_memory_context_ptr init_full() override; @@ -182,7 +188,7 @@ public: slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); // // input API @@ -309,7 +315,8 @@ public: llama_kv_cache_context( llama_kv_cache * kv, slot_info_vec_t sinfos, - std::vector ubatches); + std::vector ubatches, + bool is_inplace_update = false); virtual ~llama_kv_cache_context(); @@ -355,6 +362,10 @@ public: void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_sinfos(slot_info_vec_t new_sinfos); + + const slot_info_vec_t & get_sinfos() const { return sinfos; } + private: llama_memory_status status; @@ -387,4 +398,6 @@ private: // a heuristic, to avoid attending the full cache if 
it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; + + bool is_inplace_update = false; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d2270e8f2d..36378440e6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1720,8 +1720,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + hparams.n_layer_kv_from_start = hparams.n_layer; switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) @@ -5054,10 +5053,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // but only PROCESS up to last layer (skipping final NextN layer) in forward pass for (int i = 0; i < n_layer; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } auto & layer = layers[i]; @@ -7642,7 +7637,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + if (params.mtp_params.op_type == MTP_OP_NONE) { + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + } // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers @@ -7733,6 +7730,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) return nullptr; } +int32_t llama_model_n_nextn_layer(const llama_model * model) { + return model->hparams.nextn_predict_layers; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp index 003f70f739..c75868b16e 100644 --- a/src/models/glm4-moe.cpp +++ b/src/models/glm4-moe.cpp @@ -2,169 +2,308 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); + if (params.mtp_params.op_type != MTP_OP_NONE) { + ggml_tensor* hidden_states_from_main_model; - bool use_mrope = hparams.use_mrope(); - if (ubatch.embd && !use_mrope) { - // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results - GGML_ABORT("This GGUF does not support multimodal. 
Please reconvert it."); - } + if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; - - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - } - - if (use_mrope) { - Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - } else { - // Normal RoPE - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, - rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, - rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - } - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // Post-attention norm - cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); - - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense 
FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); } else { - // Process routed experts using existing MoE infrastructure - ggml_tensor * routed_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(routed_out, "ffn_moe_out", il); + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shared_out, "ffn_shexp_out", il); - - // Final output: routed_output + shared_output - cur = ggml_add(ctx0, routed_out, shared_out); - cb(cur, "ffn_out", il); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); } - cur = ggml_add(ctx0, cur, ffn_inp); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); + } else { + ggml_tensor * inpL; - // input for next layer - inpL = cur; + inpL = build_inp_embd(model.tok_embd); + + bool use_mrope = hparams.use_mrope(); + if (ubatch.embd && !use_mrope) { + // unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results + GGML_ABORT("This GGUF does not support multimodal. 
Please reconvert it."); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + if (use_mrope) { + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else { + // Normal RoPE + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, + rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = 
build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; } - cur = inpL; - cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); + ggml_build_forward_expand(gf, res->t_logits); +} + + +ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head) { + ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); + + const int il = hparams.n_layer - 1; + ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); + + ggml_set_name(sum_node, "mtp_input_sum"); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); + + ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); + + // now proceed through last layer (skipped in main model) + ggml_tensor * inpSA = cur; + // Pre-attention norm for the MTP block + cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur); + if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur); + if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur); + if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (mtp_layer.attn_q_norm) { + Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (mtp_layer.attn_k_norm) { + 
Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + mtp_layer.wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + + cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il); + + // moe ffn for nextn block + { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + mtp_layer.ffn_gate_inp, + mtp_layer.ffn_up_exps, + mtp_layer.ffn_gate_exps, + mtp_layer.ffn_down_exps, + mtp_layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + mtp_layer.ffn_up_shexp, NULL, NULL, + mtp_layer.ffn_gate_shexp, NULL, NULL, + mtp_layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); + cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); + + return cur; } diff --git a/src/models/models.h b/src/models/models.h index ffb36acc61..a4373b7faa 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -220,6 +220,8 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); + + ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head); }; struct llm_build_gpt2 : public llm_graph_context { diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index cde34e6533..925b9b805d 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -80,6 +80,7 @@ struct server_slot { mtmd_context * mctx = nullptr; common_speculative * spec = nullptr; + bool has_mtp = false; std::unique_ptr task; std::unique_ptr task_prev; // used for debugging @@ -206,7 +207,7 @@ struct server_slot { bool need_embd() const { GGML_ASSERT(task); - return server_task_type_need_embd(task->type); + return server_task_type_need_embd(task->type) || has_mtp; } bool need_logits() const { @@ -220,7 +221,8 @@ struct server_slot { bool can_split() const { return !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); } bool can_batch_with(server_slot & other_slot) const { @@ -252,7 +254,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft; + return (ctx_dft || has_mtp); } 
void add_token(const completion_token_output & token) { @@ -769,6 +771,18 @@ struct server_context_impl { } } + // if model has MTP and no draft model is specified... + else if (llama_model_n_nextn_layer(model) > 0 && params_base.mtp) { + SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); + slot.has_mtp = true; + + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); + + SRV_INF("%s (n_max=%d)\n", "MTP needs embeddings on decode, enabling", params_base.speculative.n_max); + llama_set_embeddings(ctx, true); + } + SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); slot.callback_on_release = [this](int) { @@ -1971,12 +1985,34 @@ struct server_context_impl { GGML_ABORT("not supported by multimodal"); } + llama_tokens draft; + struct common_speculative_params params_spec; params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; params_spec.p_min = slot.task->params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); + + if (slot.ctx_dft) { + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + } else { + params_spec.n_reuse = 0; + } + + if (slot.has_mtp) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); + + draft = mtp_speculative_gen_draft( + slot.smpl, + ctx, + params_spec, + slot.sampled, + slot.prompt.n_tokens(), + slot.id + ); + } + else { + const llama_tokens& cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); + } // add the sampled token to the batch slot.i_batch_dft.push_back(batch.n_tokens); @@ -2583,6 +2619,21 @@ struct server_context_impl { continue; // continue loop of n_batch } + if (slot_batched && slot_batched->has_mtp && + (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT)) { + + // Prepare the context to reuse the exact sinfo layout (including multiple u-batches) + // from the main model's prompt processing pass. This ensures the MTP layer's + // KV cache is perfectly aligned. + if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { + mtp_update_kv_cache(ctx, batch_view, true); + // Clean up the forced state to not affect subsequent decodes. 
+ llama_mtp_cancel_sinfo_update(ctx); + } else { + LOG_ERR("%s: Failed to prepare the MTP for warmup.", __func__); + } + } + // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens; @@ -2702,6 +2753,16 @@ struct server_context_impl { slot.i_batch_dft.clear(); slot.drafted.clear(); + if (slot.has_mtp) { + if (!ids.empty()) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1)); + } else { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0)); + } + + mtp_accept_tokens(ctx, ids, slot.prompt.n_tokens(), slot.id); + } + slot.n_decoded += ids.size(); slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; From 38c91187f9b9677f659a1805b5224dde7cf61a77 Mon Sep 17 00:00:00 2001 From: samuel Date: Wed, 10 Dec 2025 12:33:10 -0300 Subject: [PATCH 2/6] speculative: optimize graph reuse for GLM-4.5 --- src/models/glm4-moe.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp index c75868b16e..53006e1d62 100644 --- a/src/models/glm4-moe.cpp +++ b/src/models/glm4-moe.cpp @@ -198,11 +198,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head) { ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); + cb(embd_copy, "mtp_embd_copy", -1); const int il = hparams.n_layer - 1; - ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy); - - ggml_set_name(sum_node, "mtp_input_sum"); ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv(); @@ -212,6 +210,7 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + cb(combined, "mtp_concat", il); ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // now proceed through last layer (skipped in main model) @@ -269,6 +268,7 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "mtp_ffn_inp", il); cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il); @@ -302,6 +302,7 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_ffn_out_resid", il); cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); From d9576dd0377ee410b69d449b2b237835753e744c Mon Sep 17 00:00:00 2001 From: samuel Date: Wed, 10 Dec 2025 22:54:27 -0300 Subject: [PATCH 3/6] glm4: add MTP weight fallback for GLM-4.6 compatibility GLM-4.6 models exclude specific MTP tensors (`embed_tokens` and `shared_head_head`), implying weight tying with the main model. Previously, this caused a crash when building the graph. This commit adds a fallback mechanism to use the main model's token embeddings and output head when the MTP-specific tensors are missing. 
--- src/models/glm4-moe.cpp | 22 ++++++++++++++++++---- src/models/models.h | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp index 53006e1d62..0a491832f0 100644 --- a/src/models/glm4-moe.cpp +++ b/src/models/glm4-moe.cpp @@ -32,7 +32,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap const int il_mtp = hparams.n_layer - 1; const auto & mtp_layer = model.layers[il_mtp]; - res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, model); + } else { ggml_tensor * inpL; @@ -196,7 +197,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap } -ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head) { +ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, + int64_t n_embd_head, const llama_model & model) { ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings); cb(embd_copy, "mtp_embd_copy", -1); @@ -204,7 +206,13 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); + + // If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd + ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens; + if (mtp_embd_weights == nullptr) { + mtp_embd_weights = model.tok_embd; + } + ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights); ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); @@ -304,7 +312,13 @@ ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "mtp_ffn_out_resid", il); cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); - cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur); + + // If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head) + ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head; + if (mtp_head_weights == nullptr) { + mtp_head_weights = model.output; + } + cur = build_lora_mm(mtp_head_weights, cur); return cur; } diff --git a/src/models/models.h b/src/models/models.h index a4373b7faa..21a8f18717 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -221,7 +221,7 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head); + ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head, const llama_model & model); }; struct llm_build_gpt2 : public llm_graph_context { From a3e29da02add9759a24f1daf39a275427545f434 Mon Sep 17 00:00:00 2001 From: samuel Date: Fri, 19 Dec 2025 20:41:35 -0300 Subject: [PATCH 4/6] glm-moe: allow skipping MTP tensor loading to save VRAM Adds a new `mtp` boolean to `llama_model_params`. When set to false (default): 1. 
The loader skips loading MTP-specific tensors (NextN layers) using `TENSOR_SKIP`. 2. The KV cache size calculation excludes the MTP layer (`n_layer_kv_from_start`). This reduces VRAM usage and load time for users running GLM-4.5/4.6 in standard generation mode. --- common/common.cpp | 1 + common/common.h | 2 +- include/llama.h | 1 + src/llama-context.h | 2 +- src/llama-graph.h | 2 +- src/llama-model.cpp | 19 ++++++++++++++++--- 6 files changed, 21 insertions(+), 6 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d4e8c7405e..7c1297574a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1351,6 +1351,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; mparams.no_host = params.no_host; + mparams.mtp = params.mtp; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; diff --git a/common/common.h b/common/common.h index 6c2f1dc686..40d7689872 100644 --- a/common/common.h +++ b/common/common.h @@ -430,7 +430,7 @@ struct common_params { bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) bool no_host = false; // bypass host buffer allowing extra buffers to be used - bool mtp = false; // use mtp is supported + bool mtp = false; // enable MTP if supported by the model bool single_turn = false; // single turn chat conversation diff --git a/include/llama.h b/include/llama.h index ce44c2d345..0428a9085b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -326,6 +326,7 @@ extern "C" { bool use_extra_bufts; // use extra buffer types (used for weight repacking) bool no_host; // bypass host buffer allowing extra buffers to be used bool no_alloc; // only load metadata and simulate memory allocations + bool mtp; // use mtp if is supported by the Model }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/llama-context.h b/src/llama-context.h index 3bf3483b08..392796f7a3 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -337,4 +337,4 @@ private: mutable int32_t n_eval = 0; // number of eval calls mutable int32_t n_reused = 0; // number of times the previous graph was reused -}; +}; \ No newline at end of file diff --git a/src/llama-graph.h b/src/llama-graph.h index 0a0de82c48..0c9d7d1508 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -860,4 +860,4 @@ struct llm_graph_context { }; // TODO: better name -int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); \ No newline at end of file diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 36378440e6..cef7a0bd96 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1720,8 +1720,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - hparams.n_layer_kv_from_start = hparams.n_layer; - + if (params.mtp) { + // Include MTP layers in KV cache if MTP is enabled + hparams.n_layer_kv_from_start = hparams.n_layer; + } + else { + // Otherwise exclude to save memory + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + } switch (hparams.n_layer) { 
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) @@ -5050,9 +5056,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // Load ALL tensors including NextN layer to satisfy total tensor count - // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + // but skip loading data for NextN layers if MTP is disabled to save VRAM for (int i = 0; i < n_layer; ++i) { int flags = 0; + // Skip loading MTP layers if the feature is disabled + if (!params.mtp) { + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + flags |= TENSOR_SKIP; + } + } auto & layer = layers[i]; @@ -7673,6 +7685,7 @@ llama_model_params llama_model_default_params() { /*.use_extra_bufts =*/ true, /*.no_host =*/ false, /*.no_alloc =*/ false, + /*.mtp =*/ false, }; return result; From a8dc54672caf30fae6bdbd7214d8b0bc03b9b999 Mon Sep 17 00:00:00 2001 From: samuel Date: Fri, 19 Dec 2025 21:57:15 -0300 Subject: [PATCH 5/6] common: simplify speculative sampling to greedy-only for performance Removes heavy penalty checks (repetition, frequency, presence, DRY) from `common_sampler_sample_speculative`. The specialized speculative sampler now uses a pure ArgMax (Greedy) approach. This significantly reduces CPU overhead during the drafting phase, which improves overall tokens per second. --- common/sampling.cpp | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 3254f8d66c..c33d58ae5e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -670,26 +670,13 @@ std::vector common_sampler_types_from_chars(const std::stri /** * Specialized sampling for speculative drafting. * - * Prioritizes performance by using a direct ArgMax loop (Greedy) when no - * penalties (repetition, frequency, presence, DRY) are configured. - * Falls back to the full sampler chain if penalties are active to prevent - * generative loops or adhere to constraints. + * Prioritizes performance by using a direct ArgMax loop (Greedy). + * Penalties and complex sampling logic are bypassed to minimize + * drafting latency. 
*/ llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { const auto & params = gsmpl->params; - bool use_heavy_sampler = - (params.penalty_last_n > 0 && ( - params.penalty_repeat != 1.0f || - params.penalty_freq != 0.0f || - params.penalty_present != 0.0f - )) || - (params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f); - - if (use_heavy_sampler) { - return common_sampler_sample(gsmpl, ctx, idx, false); - } - float * logits = llama_get_logits_ith(ctx, idx); const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx))); From d10a5a4a5bf791108145b9e61ab50cf9b27b2359 Mon Sep 17 00:00:00 2001 From: Aaron Lee Date: Sun, 21 Dec 2025 17:53:27 -0500 Subject: [PATCH 6/6] clean up mtp sample typing after rebase --- common/sampling.cpp | 2 +- common/sampling.h | 2 ++ common/speculative.cpp | 4 ++-- common/speculative.h | 4 ++-- tools/server/server-context.cpp | 2 +- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index c33d58ae5e..27b2a082b0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -691,4 +691,4 @@ llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, str } return best_id; -} \ No newline at end of file +} diff --git a/common/sampling.h b/common/sampling.h index c7101032f2..90c2401c2f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -115,3 +115,5 @@ struct common_sampler_deleter { }; typedef std::unique_ptr common_sampler_ptr; + +llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); diff --git a/common/speculative.cpp b/common/speculative.cpp index 136f2c1b1a..548394bbe8 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -361,8 +361,8 @@ llama_tokens common_speculative_gen_draft( } llama_tokens mtp_speculative_gen_draft( - struct common_sampler* smpl, - struct llama_context* ctx, + struct common_sampler * smpl, + struct llama_context * ctx, struct common_speculative_params params, llama_token id_last, int32_t n_past, diff --git a/common/speculative.h b/common/speculative.h index a33c5a8b02..d22a752d3f 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -57,8 +57,8 @@ llama_tokens common_speculative_gen_draft( * @return std::vector The generated draft tokens. */ llama_tokens mtp_speculative_gen_draft( - struct common_sampler* smpl, - struct llama_context* ctx, + struct common_sampler * smpl, + struct llama_context * ctx, struct common_speculative_params params, llama_token id_last, int32_t n_past, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 925b9b805d..dca005da35 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2001,7 +2001,7 @@ struct server_context_impl { llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); draft = mtp_speculative_gen_draft( - slot.smpl, + slot.smpl.get(), ctx, params_spec, slot.sampled,