mtp-batch (fix): prevent mtp draft from polluting the cache

This commit is contained in:
samuel 2025-10-09 22:27:18 -03:00
parent 5e1d719bef
commit 6f74ba3807
5 changed files with 38 additions and 5 deletions

View File

@ -374,6 +374,8 @@ llama_token mtp_speculative_gen_draft(
return -1;
}
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
const llama_pos draft_pos = n_past;
const llama_seq_id draft_seq_id = 0;
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
mtp_batch.update_mtp_kv = false;
@ -387,6 +389,8 @@ llama_token mtp_speculative_gen_draft(
llama_decode(ctx, mtp_batch);
llama_batch_free(mtp_batch);
llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);

View File

@ -1460,9 +1460,13 @@ extern "C" {
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
#ifdef __cplusplus
}
#endif

View File

@ -3105,6 +3105,20 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
ctx->draft_input_hidden_state = hidden_state;
}
bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) {
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
const auto & last_sinfo = kvd->last_main_model_sinfos;
if (last_sinfo.empty()) {
LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__);
return false;
}
kvd->forced_sinfos = &last_sinfo;
return true;
}
bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
const auto & last_sinfo = kvd->last_main_model_sinfos;
@ -3126,4 +3140,14 @@ bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_acc
void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
kvd->forced_sinfos = nullptr;
}
}
void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
if (memory) {
static_cast<llama_kv_cache_unified *>(memory.get())->seq_rm(seq_id, p0, p1);
}
}
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
ctx->kv_cache_seq_rm(seq_id, p0, p1);
}

View File

@ -100,6 +100,8 @@ struct llama_context {
int32_t il_start,
int32_t il_end);
void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1);
// process a single ubatch with a specific graph type
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation

View File

@ -3520,11 +3520,10 @@ struct server_context {
needs_mtp_warmup = true;
}
}
if (needs_mtp_warmup) {
if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) {
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
mtp_update_kv_cache(ctx, batch_view, true);
llama_mtp_cancel_sinfo_update(ctx);
} else {
LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__);