fixed mtp kv cache update step in cases where the prompt size exceeds n_batch and n_ubatch
parent 98bc0c6bf2
commit 9fab53e438
@@ -423,11 +423,18 @@ llama_token mtp_speculative_gen_draft(
     }
 }
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
     mtp_kv_update_data token;
 
-    for (int i = 0; i < tokens.size(); ++i) {
+    if (n_tokens < 0) {
+        n_tokens = tokens.size();
+    }
+
+    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
         token = tokens[i];
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
+        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+
+        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
     }
 
     tokens.clear();
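Note: `n_tokens` is declared as `size_t`, so the `n_tokens < 0` check above is dead code (an unsigned value is never negative), and the `-1` default from the header below wraps to `SIZE_MAX`. The loop stays bounded only because of the `std::min(tokens.size(), n_tokens)` clamp; a signed sentinel (e.g. `int64_t`) would make the "-1 means all pending tokens" intent explicit. A minimal standalone sketch of the wraparound behavior:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> tokens = {10, 20, 30};

        size_t n_tokens = -1;   // wraps to SIZE_MAX: an unsigned type cannot hold -1
        if (n_tokens < 0) {     // always false for an unsigned type, so this branch never runs
            n_tokens = tokens.size();
        }

        // the std::min clamp is what actually bounds the loop to tokens.size()
        const size_t count = std::min(tokens.size(), n_tokens);
        printf("count = %zu\n", count);  // prints: count = 3
    }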
@@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
     const llama_tokens & prompt,
     llama_token id_last);
 
-void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
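Note: with the defaulted arguments, a bare call flushes the entire pending batch with no index rebasing, while the windowed form restricts the flush to the current n_batch window. The two call shapes, as they appear in this commit:

    // default arguments: batch_start = 0, n_tokens = (size_t) -1 -> whole pending batch
    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);

    // windowed form: only the n_tokens updates of the current n_batch window,
    // with tok_idx rebased relative to the window start i
    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);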
@@ -1405,9 +1405,14 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
     bool can_split() const {
+        //fprintf(stderr, "need_embd() %d\n", need_embd());
+        //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
+        //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+
         return
             !need_embd() ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
     }
 
     bool can_batch_with(server_slot & other_slot) const {
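Note: both new disjuncts require a memory module, so the condition can be factored and only the pooling mode differs. A self-contained sketch with hypothetical stand-ins for the slot/context state (the real method reads `need_embd()`, `llama_get_memory(ctx)`, and `llama_pooling_type(ctx)`):

    #include <cstdio>

    enum pooling_type { POOLING_LAST, POOLING_NONE, POOLING_MEAN }; // stand-ins for the llama_pooling_type values

    // equivalent factoring of the new can_split() condition
    static bool can_split(bool need_embd, bool has_memory, pooling_type pooling) {
        if (!need_embd) {
            return true; // no embeddings requested: always splittable
        }
        // LAST pools only the final token and NONE returns per-token embeddings,
        // so neither needs past tokens from an earlier ubatch
        return has_memory && (pooling == POOLING_LAST || pooling == POOLING_NONE);
    }

    int main() {
        printf("%d\n", can_split(true, true, POOLING_MEAN)); // 0: mean pooling needs the whole batch in one ubatch
        printf("%d\n", can_split(true, true, POOLING_NONE)); // 1: per-token embeddings can be split
    }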
@@ -3508,6 +3513,13 @@ struct server_context {
                 continue; // continue loop of n_batch
             }
 
+            for (auto & slot : slots) {
+                // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
+                if (slot.has_mtp) {
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
+                }
+            }
+
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;
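Note: this is the core of the fix. When the prompt is longer than n_batch, the decode loop processes it in windows of at most n_batch tokens, and the queued MTP updates are now flushed once per window with `batch_start = i`, so `tok_idx - batch_start` indexes into the current window rather than the whole prompt. A standalone sketch of the rebasing arithmetic, with hypothetical sizes:

    #include <cstdio>

    int main() {
        const int n_prompt = 5000, n_batch = 2048; // hypothetical: prompt longer than n_batch
        const int tok_idx  = 3000;                 // absolute position of a queued MTP update

        for (int i = 0; i < n_prompt; i += n_batch) {
            const int n_tokens = (n_prompt - i < n_batch) ? n_prompt - i : n_batch;
            if (tok_idx >= i && tok_idx < i + n_tokens) {
                // the token's data lives at index tok_idx - i within this window's batch
                printf("window [%d, %d): in-batch index %d\n", i, i + n_tokens, tok_idx - i);
            }
        }
        // prints: window [2048, 4096): in-batch index 952
    }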
@@ -3552,9 +3564,9 @@ struct server_context {
             common_sampler_accept(slot.smpl, id, true);
 
             // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-            if (slot.has_mtp) {
-                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
-            }
+            //if (slot.has_mtp) {
+            //    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+            //}
 
             slot.n_decoded += 1;
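Note: with the flush moved into the per-window loop above, this per-token call during generation would replay the same updates a second time, which is presumably why it is commented out here rather than removed outright.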