From 6f74ba38070d62d37bc0fb71ce9871e1a4ffabcc Mon Sep 17 00:00:00 2001 From: samuel Date: Thu, 9 Oct 2025 22:27:18 -0300 Subject: [PATCH] mtp-batch (fix): prevent mtp draft from polluting the cache --- common/speculative.cpp | 4 ++++ include/llama.h | 6 +++++- src/llama-context.cpp | 26 +++++++++++++++++++++++++- src/llama-context.h | 2 ++ tools/server/server.cpp | 5 ++--- 5 files changed, 38 insertions(+), 5 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index f71982f9e4..8249a3a52c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -374,6 +374,8 @@ llama_token mtp_speculative_gen_draft( return -1; } llama_batch mtp_batch = llama_batch_init(1, 0, 1); + const llama_pos draft_pos = n_past; + const llama_seq_id draft_seq_id = 0; common_batch_add(mtp_batch, id_last, n_past, {0}, true); mtp_batch.update_mtp_kv = false; @@ -387,6 +389,8 @@ llama_token mtp_speculative_gen_draft( llama_decode(ctx, mtp_batch); llama_batch_free(mtp_batch); + llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); + const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_n_vocab(vocab); diff --git a/include/llama.h b/include/llama.h index 024d53f21c..01e75cea62 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1460,9 +1460,13 @@ extern "C" { LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); - + + LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + #ifdef __cplusplus } #endif diff --git a/src/llama-context.cpp b/src/llama-context.cpp index edf5d747f1..8939edabaa 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3105,6 +3105,20 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float ctx->draft_input_hidden_state = hidden_state; } +bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty()) { + LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__); + return false; + } + + kvd->forced_sinfos = &last_sinfo; + return true; +} + + bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { auto * kvd = static_cast(ctx->kv_cache_data); const auto & last_sinfo = kvd->last_main_model_sinfos; @@ -3126,4 +3140,14 @@ bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_acc void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { auto * kvd = static_cast(ctx->kv_cache_data); kvd->forced_sinfos = nullptr; -} \ No newline at end of file +} + +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); + } +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} diff --git a/src/llama-context.h b/src/llama-context.h index 654409cb6c..e15a336938 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -100,6 +100,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 844805d0ce..91cc438dcc 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3520,11 +3520,10 @@ struct server_context { needs_mtp_warmup = true; } } - + if (needs_mtp_warmup) { - if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) { + if (llama_mtp_prepare_sinfo_for_warmup(ctx)) { mtp_update_kv_cache(ctx, batch_view, true); - llama_mtp_cancel_sinfo_update(ctx); } else { LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__);