From 4bcc9e261ef57ee4cfaa65d06bcd0fcdeacf7797 Mon Sep 17 00:00:00 2001
From: samuel
Date: Sat, 11 Oct 2025 18:51:22 -0300
Subject: [PATCH] mtp-batch(fix): Correctly advance cache head and add MTP
 documentation

---
 common/speculative.cpp         |  4 ++++
 include/llama.h                | 22 ++++++++++++++++++++++
 src/llama-context.h            |  4 ++++
 src/llama-kv-cache-unified.cpp | 33 +++++++++++++++------------------
 4 files changed, 45 insertions(+), 18 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 02eca967ca..a7a4042682 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -436,10 +436,13 @@ void mtp_accept_tokens(
         return;
     }
 
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
     if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
         return;
     }
 
+    // Build a new batch containing only the accepted tokens.
    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
     for (size_t i = 0; i < ids.size(); ++i) {
         common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
@@ -447,6 +450,7 @@ void mtp_accept_tokens(
 
     mtp_update_kv_cache(ctx, accepted_batch, false);
 
+    // Clean up the forced state to not affect subsequent, normal decode calls.
     llama_mtp_cancel_sinfo_update(ctx);
 
     llama_batch_free(accepted_batch);
diff --git a/include/llama.h b/include/llama.h
index 89c8510310..0b15d4bf1c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1466,14 +1466,36 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
+    //
+    // MTP
+    //
+
     LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
 
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
     LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
 
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
     LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
 
 #ifdef __cplusplus
diff --git a/src/llama-context.h b/src/llama-context.h
index ab854c1af1..4d77d5d81a 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -30,6 +30,10 @@ struct llama_context {
 
     ~llama_context();
 
+    // The llama_context manages significant resources (GPU memory, file handles, PImpl data)
+    // and is fundamentally a non-copyable, non-movable object. Deleting these special
+    // member functions enforces this rule and is also technically required to allow the
+    // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header.
     llama_context(const llama_context &) = delete;
     llama_context & operator=(const llama_context &) = delete;
     llama_context(llama_context &&) = delete;
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 90ee8f726e..8d9b1f631f 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -977,6 +977,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 
 void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
+    // For "in-place" updates (MTP warmup/accept), we only update the tensor data.
+    // The cell metadata (logical position, sequence ID) has already been set
+    // by the main model's pass. We must skip all metadata modifications
+    // to prevent `pos_set` from asserting on an already-set cell.
     if (!is_inplace_update) {
         // keep track of the max sequence position that we would overwrite with this ubatch
         // for non-SWA cache, this would be always empty
@@ -995,17 +999,12 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 
             const auto idx = sinfo.idxs[s][ii];
 
-            if (!is_inplace_update) {
-                if (!cells.is_empty(idx)) {
-                    assert(cells.seq_count(idx) == 1);
-
-                    const llama_seq_id seq_id = cells.seq_get(idx);
-                    const llama_pos pos = cells.pos_get(idx);
-
-                    seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                    cells.rm(idx);
-                }
+            if (!cells.is_empty(idx)) {
+                assert(cells.seq_count(idx) == 1);
+                const llama_seq_id seq_id = cells.seq_get(idx);
+                const llama_pos pos = cells.pos_get(idx);
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                cells.rm(idx);
             }
 
             cells.pos_set(idx, ubatch.pos[i]);
@@ -1029,19 +1028,17 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
             auto & cells = v_cells[seq_to_stream[s]];
 
             if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
-                LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
-                        __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
 
                 seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
             }
         }
+    }
 
-        // move the head at the end of the slot
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            auto & head = v_heads[sinfo.strm[s]];
+    // move the head at the end of the slot
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
 
-            head = sinfo.idxs[s].back() + 1;
-        }
+        head = sinfo.idxs[s].back() + 1;
     }
 }
 
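
For context, a minimal sketch of how the new MTP entry points are intended to be sequenced, condensed from the mtp_accept_tokens() change above. The variables ctx, ids, n_past_base, and seq_id are assumed to be supplied by the caller, exactly as in common/speculative.cpp; this is an illustration of the call order, not an additional API.

    // 1. Resize the last validation sinfo down to the accepted tokens; this arms
    //    the "forced sinfo" state consumed by the next MTP cache update.
    if (llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
        // 2. Batch only the accepted tokens at their final positions.
        llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
        for (size_t i = 0; i < ids.size(); ++i) {
            common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
        }

        // 3. Update the MTP KV cache in place, reusing the prepared sinfo.
        mtp_update_kv_cache(ctx, accepted_batch, false);

        // 4. Clear the forced sinfo so subsequent, normal decode calls are unaffected.
        llama_mtp_cancel_sinfo_update(ctx);

        llama_batch_free(accepted_batch);
    }

The prompt-warmup path is analogous but calls llama_mtp_prepare_sinfo_for_warmup() instead, so the MTP cache reuses the exact sinfo of the last main-model decode.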