From febd8235d27fe9174ee4b54ea7a10e630939fee0 Mon Sep 17 00:00:00 2001
From: samuel
Date: Sun, 5 Oct 2025 14:43:40 -0300
Subject: [PATCH] mtp-batch (wip): fix how to warmup kv cache for MTP

---
 common/speculative.cpp  | 29 ++++++++---------------------
 common/speculative.h    |  3 +--
 src/llama-model.cpp     | 17 +++++++++++++++--
 tools/server/server.cpp | 40 ++++++++++++++++++----------------------
 4 files changed, 42 insertions(+), 47 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 7a17b8d965..2e0b91a4e2 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -403,36 +403,23 @@ llama_token mtp_speculative_gen_draft(
 }
 
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
-    bool is_prompt_warmup) {
-
-    if (tokens.empty()) {
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
+    if (batch.n_tokens == 0) {
         return;
     }
 
-    const size_t n_to_process = tokens.size();
-    std::string details_str;
-    for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
-        details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}";
-    }
-    LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", n_to_process, details_str.c_str());
-
-    llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
-
-    for (size_t i = 0; i < n_to_process; ++i) {
-        const mtp_kv_update_data& token_data = tokens[i];
-        // Check seq_id {0}, it may be a problem with multiple sequences.
-        common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
-    }
+    LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
+    llama_batch mtp_batch = batch;
 
     mtp_batch.update_mtp_kv = true;
     mtp_batch.use_mtp_head = true;
     mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup;
 
-    llama_decode(ctx, mtp_batch);
+    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
+        mtp_batch.logits[i] = false;
+    }
 
-    llama_batch_free(mtp_batch);
-    tokens.clear();
+    llama_decode(ctx, mtp_batch);
 }
 
 // Debug function - It will be removed later
diff --git a/common/speculative.h b/common/speculative.h
index 11c0d4553e..e121e8ed14 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -49,7 +49,6 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
-    bool is_prompt_warmup);
+void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
 
 double calculate_vector_sum_double(const float* vec, size_t size);
\ No newline at end of file
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 1f00bb7dd7..6ca53a80cd 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13788,6 +13788,13 @@ struct llm_build_glm4 : public llm_graph_context {
 
 struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        LLAMA_LOG_WARN(
+            "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n",
+            params.use_mtp_head ? "MTP" : "MAIN",
+            params.update_mtp_kv ? "true" : "false",
+            n_tokens,
+            n_tokens > 0 ? ubatch.pos[0] : -1
+        );
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13906,7 +13913,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
-
+            if (ubatch.n_tokens > 0) {
+                LLAMA_LOG_WARN("[KV_WRITE] path=MAIN, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
+                    il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
+            }
             cur = build_attn(inp_attn,
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -14060,7 +14070,10 @@ private:
         // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]);
         // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]);
        // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]);
-
+        if (ubatch.n_tokens > 0) {
+            LLAMA_LOG_WARN("[KV_WRITE] path=MTP, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
+                il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
+        }
         cur = build_attn(inp_attn,
                 mtp_layer.wo, NULL,
                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 3d025b2120..3399e16823 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1296,7 +1296,6 @@ struct server_slot {
     common_speculative * spec = nullptr;
 
     bool has_mtp = false;
-    std::vector<mtp_kv_update_data> mtp_kv_update_batch;
     int32_t last_tok_idx = -1;
 
     std::vector<common_adapter_lora_info> lora;
@@ -3387,9 +3386,6 @@ struct server_context {
                     slot.n_prompt_tokens_processed += n_pos;
                 }
 
-                if (slot.has_mtp) {
-                    slot.mtp_kv_update_batch.clear();
-                }
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                     // get next token to process
@@ -3401,9 +3397,6 @@ struct server_context {
                     // embedding requires all tokens in the batch to be output
                    const bool need_embd = server_task_type_need_embd(slot.task_type);
 
-                    if (slot.has_mtp) {
-                        slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens });
-                    }
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
 
                     slot.cache_tokens.push_back(cur_tok);
@@ -3520,19 +3513,17 @@ struct server_context {
 
                 continue; // continue loop of n_batch
             }
-
-            // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-            for (auto & slot : slots) {
-                if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) {
-                    SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n",
-                        slot.mtp_kv_update_batch.size(),
-                        slot.mtp_kv_update_batch.front().n_past,
-                        slot.mtp_kv_update_batch.back().n_past
-                    );
-                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, true);
+
+            bool needs_mtp_warmup = false;
+            if (slot_batched && slot_batched->has_mtp) {
+                if (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT) {
+                    needs_mtp_warmup = true;
                 }
             }
 
+            if (needs_mtp_warmup) {
+                mtp_update_kv_cache(ctx, batch_view, true);
+            }
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;
 
@@ -3704,12 +3695,17 @@ struct server_context {
 
                 double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats);
                 SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);
 
-                slot.mtp_kv_update_batch.clear();
-                for (int32_t i = 0; i < ids.size(); ++i) {
-                    slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
-                }
-                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, false);
+                if (!ids.empty()) {
+                    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+                    for (size_t i = 0; i < ids.size(); ++i) {
+                        common_batch_add(accepted_batch, ids[i], slot.n_past + i, { slot.id }, false);
+                    }
+
+                    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+                    llama_batch_free(accepted_batch);
+                }
 
                 const float* embd_after_update_ptr = llama_get_embeddings(ctx);
                 double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
                 SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);