From 9fab53e4388c20aef497efd82e86dcb99ca58064 Mon Sep 17 00:00:00 2001
From: Aaron Lee
Date: Tue, 2 Sep 2025 17:14:09 -0400
Subject: [PATCH] fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch

---
 common/speculative.cpp  | 13 ++++++++++---
 common/speculative.h    |  2 +-
 tools/server/server.cpp | 20 ++++++++++++++++----
 3 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index edeffe2d8e..c1d9149ea1 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft(
 }
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
     mtp_kv_update_data token;
-    for (int i = 0; i < tokens.size(); ++i) {
+
+    if (n_tokens == (size_t) -1) {
+        n_tokens = tokens.size();
+    }
+
+    for (size_t i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
         token = tokens[i];
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
+        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+
+        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
     }
 
     tokens.clear();
diff --git a/common/speculative.h b/common/speculative.h
index 786f3ad1e8..bb29c07bb6 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = (size_t) -1);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 1191564dd2..34053cd040 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1405,9 +1405,14 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        //fprintf(stderr, "need_embd() %d\n", need_embd());
+        //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
+        //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+
         return !need_embd() ||
-               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
     }
 
     bool can_batch_with(server_slot & other_slot) const {
@@ -3508,6 +3513,13 @@ struct server_context {
                 continue; // continue loop of n_batch
             }
 
+            for (auto & slot : slots) {
+                // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
+                if (slot.has_mtp) {
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
+                }
+            }
+
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;
@@ -3552,9 +3564,9 @@ struct server_context {
                     common_sampler_accept(slot.smpl, id, true);
 
                     // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-                    if (slot.has_mtp) {
-                        mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
-                    }
+                    //if (slot.has_mtp) {
+                    //    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                    //}
 
                     slot.n_decoded += 1;
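
Note (not part of the patch): below is a minimal, self-contained sketch of the chunked update the patch is aiming at, useful for checking the index math in isolation. Everything in it is hypothetical stand-in code, not llama.cpp API: flush_mtp_kv stands in for mtp_update_kv_cache / mtp_speculative_gen_draft, and the prompt/batch sizes are made up. Only the mtp_kv_update_data fields (id, n_past, tok_idx) and the tok_idx - batch_start rebasing mirror the diff. The assumption is that the server decodes a long prompt in chunks of at most n_batch tokens while tok_idx is recorded relative to the whole prompt, so each per-chunk flush has to subtract the chunk start to get an index that is local to the current decode call.

    // Hypothetical illustration only -- not llama.cpp code.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Mirrors the fields the patch reads from mtp_kv_update_data.
    struct mtp_kv_update_data {
        int    id;      // token id
        int    n_past;  // absolute position in the sequence
        size_t tok_idx; // absolute index of the token in the full prompt
    };

    // Stand-in for mtp_update_kv_cache(): process at most n_tokens pending entries,
    // rebasing tok_idx against the start of the current chunk, then clear the queue.
    static void flush_mtp_kv(std::vector<mtp_kv_update_data> & pending,
                             size_t batch_start, size_t n_tokens) {
        const size_t n = std::min(pending.size(), n_tokens);
        for (size_t i = 0; i < n; ++i) {
            const mtp_kv_update_data & t = pending[i];
            // the real code calls mtp_speculative_gen_draft(nullptr, ctx, t.id, t.n_past, local_idx)
            std::printf("MTP KV update: id=%d n_past=%d local_idx=%zu\n",
                        t.id, t.n_past, t.tok_idx - batch_start);
        }
        pending.clear();
    }

    int main() {
        const size_t n_prompt = 10; // prompt longer than n_batch
        const size_t n_batch  = 4;

        std::vector<mtp_kv_update_data> pending;

        // Walk the prompt in chunks of at most n_batch tokens, as the server loop does,
        // and flush the MTP KV updates for each chunk right after it is "decoded".
        for (size_t i = 0; i < n_prompt; ) {
            const size_t n_tokens = std::min(n_batch, n_prompt - i);

            for (size_t k = i; k < i + n_tokens; ++k) {
                pending.push_back({ /*id=*/100 + (int) k, /*n_past=*/(int) k, /*tok_idx=*/k });
            }

            // ... llama_decode() on tokens [i, i + n_tokens) would happen here ...

            flush_mtp_kv(pending, /*batch_start=*/i, n_tokens);

            i += n_tokens;
        }
        return 0;
    }

As far as the diff shows, the essential change is passing batch_start so the draft call receives a chunk-local index; the n_tokens bound appears to be a safety limit that keeps the flush from reading past the entries belonging to the chunk that was actually decoded.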