fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch

This commit is contained in:
Aaron Lee 2025-09-02 17:14:09 -04:00
parent 98bc0c6bf2
commit 9fab53e438
3 changed files with 27 additions and 8 deletions

View File

@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft(
}
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
mtp_kv_update_data token;
for (int i = 0; i < tokens.size(); ++i) {
if (n_tokens < 0) {
n_tokens = tokens.size();
}
for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
token = tokens[i];
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
//fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
}
tokens.clear();

View File

@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
const llama_tokens & prompt,
llama_token id_last);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);

View File

@ -1405,9 +1405,14 @@ struct server_slot {
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
bool can_split() const {
//fprintf(stderr, "need_embd() %d\n", need_embd());
//fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
//fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
return
!need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
}
bool can_batch_with(server_slot & other_slot) const {
@ -3508,6 +3513,13 @@ struct server_context {
continue; // continue loop of n_batch
}
for (auto & slot : slots) {
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
if (slot.has_mtp) {
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
}
}
// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;
@ -3552,9 +3564,9 @@ struct server_context {
common_sampler_accept(slot.smpl, id, true);
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
if (slot.has_mtp) {
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
}
//if (slot.has_mtp) {
// mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
//}
slot.n_decoded += 1;