fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch
This commit is contained in:
parent
98bc0c6bf2
commit
9fab53e438
|
|
@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft(
|
|||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
|
||||
mtp_kv_update_data token;
|
||||
for (int i = 0; i < tokens.size(); ++i) {
|
||||
|
||||
if (n_tokens < 0) {
|
||||
n_tokens = tokens.size();
|
||||
}
|
||||
|
||||
for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
|
||||
token = tokens[i];
|
||||
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
|
||||
//fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
|
||||
|
||||
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
|
||||
}
|
||||
|
||||
tokens.clear();
|
||||
|
|
|
|||
|
|
@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
|
|||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
|
||||
|
|
|
|||
|
|
@ -1405,9 +1405,14 @@ struct server_slot {
|
|||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
bool can_split() const {
|
||||
//fprintf(stderr, "need_embd() %d\n", need_embd());
|
||||
//fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
|
||||
//fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
|
||||
|
||||
return
|
||||
!need_embd() ||
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
|
||||
}
|
||||
|
||||
bool can_batch_with(server_slot & other_slot) const {
|
||||
|
|
@ -3508,6 +3513,13 @@ struct server_context {
|
|||
continue; // continue loop of n_batch
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
|
||||
if (slot.has_mtp) {
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
// move the head of the batch forward with the number of tokens we just processed
|
||||
i_next = i + n_tokens;
|
||||
|
||||
|
|
@ -3552,9 +3564,9 @@ struct server_context {
|
|||
common_sampler_accept(slot.smpl, id, true);
|
||||
|
||||
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
|
||||
if (slot.has_mtp) {
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
}
|
||||
//if (slot.has_mtp) {
|
||||
// mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
//}
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue