mtp-batch(fix): Correctly advance cache head and add MTP documentation
parent b4cbe030ac
commit 4bcc9e261e
@@ -436,10 +436,13 @@ void mtp_accept_tokens(
         return;
     }
 
+    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
+    // This sets up the context for a "forced sinfo" decode.
     if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
         return;
     }
 
+    // Build a new batch containing only the accepted tokens.
     llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
     for (size_t i = 0; i < ids.size(); ++i) {
         common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
@@ -447,6 +450,7 @@ void mtp_accept_tokens(
 
     mtp_update_kv_cache(ctx, accepted_batch, false);
 
+    // Clean up the forced state to not affect subsequent, normal decode calls.
    llama_mtp_cancel_sinfo_update(ctx);
 
     llama_batch_free(accepted_batch);
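For orientation, this is roughly how a caller would drive the accept path above. A minimal sketch only: the argument order of mtp_accept_tokens is assumed from its body (ctx, accepted ids, n_past_base, seq_id), and draft/sampled are placeholder names for the drafted tokens and the tokens actually sampled during validation, not part of this commit.

    // Keep the longest accepted prefix of the draft, then hand it to
    // mtp_accept_tokens() so the MTP KV cache is updated in place for those tokens.
    std::vector<llama_token> accepted;
    for (size_t i = 0; i < draft.size(); ++i) {
        if (draft[i] != sampled[i]) {
            break; // first mismatch ends the accepted prefix
        }
        accepted.push_back(draft[i]);
    }

    if (!accepted.empty()) {
        mtp_accept_tokens(ctx, accepted, n_past_base, seq_id);
    }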
@@ -1466,14 +1466,36 @@ extern "C" {
             ggml_opt_epoch_callback  callback_train,
             ggml_opt_epoch_callback  callback_eval);
 
+    //
+    // MTP
+    //
+
     LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
+     * This is used after speculative validation when only a subset of draft tokens are accepted.
+     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
 
+    /**
+     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
+     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
+     * @return true on success.
+     */
     LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
 
+    /**
+     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
+     */
     LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
 
+    /**
+     * @brief Removes KV cache metadata for a specified sequence and token range.
+     * This makes the physical cells logically available again without deleting the tensor data.
+     */
     LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
 
 #ifdef __cplusplus
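Read together, the new doc comments describe two update paths for the MTP head's KV cache. A rough sketch of how the calls might fit together, based only on the comments above and the mtp_accept_tokens hunk; mtp_update_kv_cache is the helper from the first hunk and its third argument is assumed to select the warmup path, prompt_batch/accepted_batch/n_accepted/n_past_base/seq_id are placeholders, and a negative p1 is assumed to follow the usual llama.cpp convention of "to the end of the sequence":

    // 1) Prompt warmup: reuse the sinfo of the main model's prompt decode so the
    //    MTP cache lands in exactly the same cells.
    if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
        mtp_update_kv_cache(ctx, prompt_batch, true);   // assumed: true selects the warmup path
        llama_mtp_cancel_sinfo_update(ctx);             // always clear the forced sinfo afterwards
    }

    // 2) After speculative validation: resize the sinfo to the accepted prefix,
    //    update the MTP cache in place, then drop metadata for the rejected tail.
    if (llama_mtp_prepare_sinfo_for_update(ctx, n_accepted)) {
        mtp_update_kv_cache(ctx, accepted_batch, false);
        llama_mtp_cancel_sinfo_update(ctx);
    }
    llama_kv_cache_seq_rm(ctx, seq_id, n_past_base + n_accepted, -1);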
@@ -30,6 +30,10 @@ struct llama_context {
 
     ~llama_context();
 
+    // The llama_context manages significant resources (GPU memory, file handles, PImpl data)
+    // and is fundamentally a non-copyable, non-movable object. Deleting these special
+    // member functions enforces this rule and is also technically required to allow the
+    // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header.
     llama_context(const llama_context &) = delete;
     llama_context & operator=(const llama_context &) = delete;
     llama_context(llama_context &&) = delete;
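As an aside, the idiom the new comment describes can be shown in isolation; a generic sketch, not code from this repository:

#include <memory>

struct impl;                                     // incomplete type, defined only in the .cpp

struct handle {
    handle();
    ~handle();                                   // defined out of line, where `impl` is complete

    handle(const handle &)             = delete; // non-copyable
    handle & operator=(const handle &) = delete;
    handle(handle &&)                  = delete; // non-movable

    std::unique_ptr<impl> pimpl;                 // owns the hidden implementation
};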
@@ -977,6 +977,10 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 
 void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
+    // For "in-place" updates (MTP warmup/accept), we only update the tensor data.
+    // The cell metadata (logical position, sequence ID) has already been set
+    // by the main model's pass. We must skip all metadata modifications
+    // to prevent `pos_set` from asserting on an already-set cell.
     if (!is_inplace_update) {
         // keep track of the max sequence position that we would overwrite with this ubatch
         // for non-SWA cache, this would be always empty
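The assertion that the new comment refers to can be illustrated with a toy version of the cell bookkeeping (illustrative only, not the real llama.cpp types): writing a position into a cell that the main model's pass already populated would trip the assert, which is why the metadata pass is skipped for in-place updates.

#include <cassert>
#include <vector>

struct toy_cells {
    std::vector<int> pos;                       // -1 means "empty"

    explicit toy_cells(size_t n) : pos(n, -1) {}

    void pos_set(size_t idx, int p) {
        assert(pos[idx] == -1);                 // fires if the cell is already set
        pos[idx] = p;
    }
};

int main() {
    toy_cells cells(4);
    cells.pos_set(0, 10);                       // fine: cell was empty
    // cells.pos_set(0, 10);                    // would assert: cell 0 is already set
    return 0;
}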
@@ -995,17 +999,12 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 
                 const auto idx = sinfo.idxs[s][ii];
 
-                if (!is_inplace_update) {
-                    if (!cells.is_empty(idx)) {
-                        assert(cells.seq_count(idx) == 1);
-
-                        const llama_seq_id seq_id = cells.seq_get(idx);
-                        const llama_pos    pos    = cells.pos_get(idx);
-
-                        seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-                        cells.rm(idx);
-                    }
+                if (!cells.is_empty(idx)) {
+                    assert(cells.seq_count(idx) == 1);
+                    const llama_seq_id seq_id = cells.seq_get(idx);
+                    const llama_pos    pos    = cells.pos_get(idx);
+                    seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                    cells.rm(idx);
+                }
 
                 cells.pos_set(idx, ubatch.pos[i]);
@@ -1029,19 +1028,17 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
             auto & cells = v_cells[seq_to_stream[s]];
 
             if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
                 LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
                     __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
 
                 seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
             }
         }
+    }
 
-        // move the head at the end of the slot
-        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-            auto & head = v_heads[sinfo.strm[s]];
+    // move the head at the end of the slot
+    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+        auto & head = v_heads[sinfo.strm[s]];
 
-            head = sinfo.idxs[s].back() + 1;
-        }
-    }
+        head = sinfo.idxs[s].back() + 1;
+    }
 }
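To make the head update concrete, a small worked example of the arithmetic in the loop above (toy values, not from the repository): the per-stream head is set one past the last cell of the slot that was just written, so the next slot search starts right after it.

#include <cstdint>
#include <vector>

int main() {
    // Suppose validation wrote 4 draft tokens into cells 100..103, but only 2
    // were accepted, so the resized sinfo for the MTP update covers cells 100..101.
    std::vector<uint32_t> idxs = {100, 101};

    // head = sinfo.idxs[s].back() + 1, as in apply_ubatch()
    uint32_t head = idxs.back() + 1;            // head == 102

    return head == 102 ? 0 : 1;
}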