From 3da7e7f3309dbb576538850c92c1cbf8fdc6d6ee Mon Sep 17 00:00:00 2001
From: samuel
Date: Tue, 23 Sep 2025 22:45:11 -0300
Subject: [PATCH] mtp-batch (fix): warm mtp cache for small batch size

---
 common/speculative.cpp  | 17 ++++++++---------
 common/speculative.h    |  2 +-
 src/llama-context.cpp   |  9 ++++++++-
 tools/server/server.cpp | 20 +++++++++++++-------
 4 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 1604dbd48a..950a9a54bc 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -373,15 +373,17 @@ llama_token mtp_speculative_gen_draft(
     if (!smpl) {
         return -1;
     }
-    const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx);
+    const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, -1);
     llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);
+    LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n",
+        (void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state);
 
     llama_batch mtp_batch = llama_batch_init(1, 0, 1);
     common_batch_add(mtp_batch, id_last, n_past, {0}, true);
 
     LOG_INF(
         "[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n",
-        id_last, n_past, last_tok_idx
+        id_last, n_past, -1
     );
 
     mtp_batch.update_mtp_kv = false;
@@ -411,15 +413,12 @@ llama_token mtp_speculative_gen_draft(
 }
 
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) {
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) {
     if (tokens.empty()) {
-        tokens.clear();
        return;
    }
-    if (n_tokens < 0) {
-        n_tokens = tokens.size();
-    }
-    const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens);
+
+    const size_t n_to_process = tokens.size();
 
     LOG_DBG(
         "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
@@ -438,5 +437,5 @@
diff --git a/common/speculative.h b/common/speculative.h
--- a/common/speculative.h
+++ b/common/speculative.h
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1);
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 9d77ce3079..5427f29eb7 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1146,7 +1146,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
             // needs to happen before the graph is built
            n_outputs = n_outputs_new;
        }
-
+        if (do_mtp_kv_update) {
+            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
+            std::string positions_str;
+            for (int i = 0; i < ubatch.n_tokens; ++i) {
+                positions_str += std::to_string(ubatch.pos[i]) + " ";
+            }
+            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s\n", positions_str.c_str());
+        }
 
         ggml_status status;
         const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 7070a56159..ddd7b6afa8 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3525,13 +3525,18 @@ struct server_context {
                 continue; // continue loop of n_batch
             }
 
+            for (auto & slot : slots) {
+                if (slot.has_mtp && slot.n_past == slot.n_prompt_tokens) {
+                    SLT_INF(slot, "Prompt processing finished. Warming up MTP KV cache for %d tokens.\n", slot.n_prompt_tokens);
+                    slot.mtp_kv_update_batch.clear();
-            // for (auto & slot : slots) {
-            //     // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-            //     if (slot.has_mtp) {
-            //         mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
-            //     }
-            // }
+                    for (int j = 0; j < slot.n_prompt_tokens; ++j) {
+                        slot.mtp_kv_update_batch.push_back({ slot.prompt_tokens[j], (llama_pos)j, j });
+                    }
+
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                }
+            }
 
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;
@@ -3697,8 +3702,9 @@ struct server_context {
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
 
             if (slot.has_mtp) {
+                slot.mtp_kv_update_batch.clear();
                 for (int32_t i = 0; i < ids.size(); ++i) {
-                    slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
+                    slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
                 }
                 mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
             }