mtp-batch(fix): Correctly advance cache head and add MTP documentation

samuel 2025-10-11 18:51:22 -03:00
parent b4cbe030ac
commit 4bcc9e261e
4 changed files with 45 additions and 18 deletions


@@ -436,10 +436,13 @@ void mtp_accept_tokens(
return;
}
// Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
// This sets up the context for a "forced sinfo" decode.
if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
return;
}
// Build a new batch containing only the accepted tokens.
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
for (size_t i = 0; i < ids.size(); ++i) {
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
@@ -447,6 +450,7 @@ void mtp_accept_tokens(
mtp_update_kv_cache(ctx, accepted_batch, false);
// Clean up the forced state so it does not affect subsequent, normal decode calls.
llama_mtp_cancel_sinfo_update(ctx);
llama_batch_free(accepted_batch);
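
For context, a caller-side sketch of how this helper might be driven after speculative validation. Everything below is illustrative: the comparison loop and variable names are hypothetical, and the parameter order of mtp_accept_tokens() is assumed from the variables used in the hunk above (ctx, ids, n_past_base, seq_id).

#include <vector>

// Hypothetical caller: keep the longest prefix of the draft that the main model
// agreed with, then let mtp_accept_tokens() re-decode it through the MTP head so
// its KV cache catches up with the main model.
static void accept_agreed_prefix(
        llama_context * ctx,
        const std::vector<llama_token> & draft,    // tokens proposed by the MTP head
        const std::vector<llama_token> & sampled,  // tokens the main model sampled
        llama_pos                        n_past_base,
        llama_seq_id                     seq_id) {
    std::vector<llama_token> accepted;
    for (size_t i = 0; i < draft.size() && i < sampled.size(); ++i) {
        if (draft[i] != sampled[i]) {
            break; // the first disagreement ends the accepted prefix
        }
        accepted.push_back(draft[i]);
    }
    if (!accepted.empty()) {
        mtp_accept_tokens(ctx, accepted, n_past_base, seq_id); // signature assumed, see above
    }
}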


@@ -1466,14 +1466,36 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
//
// MTP
//
/**
* @brief Sets the hidden state used as input for the MTP draft head.
*/
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
/**
* @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
* This is used after speculative validation when only a subset of the draft tokens is accepted.
* @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
* @return true on success.
*/
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
/**
* @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
* This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
* @return true on success.
*/
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
/**
* @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
*/
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
/**
* @brief Removes KV cache metadata for a specified sequence and token range.
* This makes the physical cells logically available again without deleting the tensor data.
*/
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
#ifdef __cplusplus
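
Taken together, these declarations imply a prepare → decode → cancel bracket. The sketch below is illustrative only: the batch and position variables are hypothetical, llama_decode() stands in for whatever call evaluates the MTP update graph in this fork (mtp_update_kv_cache() in the first hunk), and the [p0, p1) range convention for llama_kv_cache_seq_rm() is assumed from the internal seq_rm() call later in this commit.

// Hypothetical usage of the MTP cache-maintenance API declared above.
static void mtp_cache_update_example(
        llama_context * ctx,
        llama_batch     prompt_batch,    // warmup batch, illustrative
        llama_batch     accepted_batch,  // accepted draft tokens, illustrative
        size_t          n_accepted,
        llama_seq_id    seq_id,
        llama_pos       first_rejected_pos,
        llama_pos       end_pos) {
    // Prompt warmup: reuse the sinfo of the last main-model decode so the MTP
    // cache writes land in exactly the same cells as the main model's.
    if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
        llama_decode(ctx, prompt_batch);
        llama_mtp_cancel_sinfo_update(ctx); // must follow every forced-sinfo decode
    }

    // After validation: shrink the sinfo to the accepted prefix, then update.
    if (llama_mtp_prepare_sinfo_for_update(ctx, n_accepted)) {
        llama_decode(ctx, accepted_batch);
        llama_mtp_cancel_sinfo_update(ctx);
    }

    // Drop the metadata of rejected draft positions; the tensor data stays in place.
    llama_kv_cache_seq_rm(ctx, seq_id, first_rejected_pos, end_pos);
}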


@@ -30,6 +30,10 @@ struct llama_context {
~llama_context();
// The llama_context manages significant resources (GPU memory, file handles, PImpl data)
// and is fundamentally a non-copyable, non-movable object. Deleting these special
// member functions enforces this rule and is also technically required to allow the
// PImpl pattern (via unique_ptr or void*) with an incomplete type in the header.
llama_context(const llama_context &) = delete;
llama_context & operator=(const llama_context &) = delete;
llama_context(llama_context &&) = delete;
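
As a general C++ illustration of the comment above (not code from this repository): with a unique_ptr to an incomplete type, the destructor is declared in the header and defined where the impl type is complete, and the copy/move members are deleted so the compiler never has to generate operations that would need the full definition. The names below are made up.

#include <memory>

struct widget_impl; // defined only in the .cpp file

class widget {
public:
    widget();
    ~widget(); // defined out of line, where widget_impl is a complete type

    // Non-copyable, non-movable: documents unique ownership of the resources and
    // keeps the compiler from instantiating special members that would require
    // the complete widget_impl definition here in the header.
    widget(const widget &)             = delete;
    widget & operator=(const widget &) = delete;
    widget(widget &&)                  = delete;
    widget & operator=(widget &&)      = delete;

private:
    std::unique_ptr<widget_impl> pimpl; // fine with an incomplete type at this point
};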


@@ -977,6 +977,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
}
void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
// For "in-place" updates (MTP warmup/accept), we only update the tensor data.
// The cell metadata (logical position, sequence ID) has already been set
// by the main model's pass. We must skip all metadata modifications
// to prevent `pos_set` from asserting on an already-set cell.
if (!is_inplace_update) {
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
@@ -995,17 +999,12 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
const auto idx = sinfo.idxs[s][ii];
if (!is_inplace_update) {
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(idx);
}
if (!cells.is_empty(idx)) {
assert(cells.seq_count(idx) == 1);
const llama_seq_id seq_id = cells.seq_get(idx);
const llama_pos pos = cells.pos_get(idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(idx);
}
cells.pos_set(idx, ubatch.pos[i]);
@@ -1029,19 +1028,17 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
auto & cells = v_cells[seq_to_stream[s]];
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
}
}
}
// move the head at the end of the slot
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
auto & head = v_heads[sinfo.strm[s]];
// move the head at the end of the slot
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
auto & head = v_heads[sinfo.strm[s]];
head = sinfo.idxs[s].back() + 1;
}
head = sinfo.idxs[s].back() + 1;
}
}
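
A toy model of the behaviour these two hunks aim at, under the assumption (suggested by the commit title) that the per-stream head update now also runs for in-place MTP updates; this is not the real llama_kv_cache_unified, and all names are illustrative.

#include <cstdint>
#include <vector>

struct toy_cache {
    std::vector<uint32_t> heads; // one search head per stream

    void apply_slot(uint32_t stream,
                    const std::vector<uint32_t> & idxs, // cell indices used by the slot
                    bool is_inplace_update) {
        if (idxs.empty()) {
            return;
        }
        if (!is_inplace_update) {
            // normal path only: (re)write the cell metadata (positions, sequence ids, ...)
        }
        // in this sketch the head update is unconditional: after the slot is applied,
        // the head points one past the last cell the slot used, so the next decode
        // continues from the right place even when only the tensor data was touched
        heads[stream] = idxs.back() + 1;
    }
};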