mtp-batch (fix): warm mtp cache for small batch size
This commit is contained in:
parent
df64508b93
commit
3da7e7f330
|
|
@ -373,15 +373,17 @@ llama_token mtp_speculative_gen_draft(
|
|||
if (!smpl) {
|
||||
return -1;
|
||||
}
|
||||
const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx);
|
||||
const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, -1);
|
||||
llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);
|
||||
LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n",
|
||||
(void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state);
|
||||
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
|
||||
|
||||
LOG_INF(
|
||||
"[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n",
|
||||
id_last, n_past, last_tok_idx
|
||||
id_last, n_past, draft_input_hidden_state
|
||||
);
|
||||
|
||||
mtp_batch.update_mtp_kv = false;
|
||||
|
|
@ -411,15 +413,12 @@ llama_token mtp_speculative_gen_draft(
|
|||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
|
||||
if (tokens.empty()) {
|
||||
tokens.clear();
|
||||
return;
|
||||
}
|
||||
if (n_tokens < 0) {
|
||||
n_tokens = tokens.size();
|
||||
}
|
||||
const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens);
|
||||
|
||||
const size_t n_to_process = tokens.size();
|
||||
|
||||
LOG_DBG(
|
||||
"[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
|
||||
|
|
@ -438,5 +437,5 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
|
|||
llama_decode(ctx, mtp_batch);
|
||||
|
||||
llama_batch_free(mtp_batch);
|
||||
tokens.clear();
|
||||
tokens.clear();
|
||||
}
|
||||
|
|
@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
|
|||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
|
||||
|
|
|
|||
|
|
@ -1146,7 +1146,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// needs to happen before the graph is built
|
||||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
if (do_mtp_kv_update) {
|
||||
LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
|
||||
std::string positions_str;
|
||||
for (int i = 0; i < ubatch.n_tokens; ++i) {
|
||||
positions_str += std::to_string(ubatch.pos[i]) + " ";
|
||||
}
|
||||
LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s\n", positions_str.c_str());
|
||||
}
|
||||
ggml_status status;
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);
|
||||
|
||||
|
|
|
|||
|
|
@ -3525,13 +3525,18 @@ struct server_context {
|
|||
|
||||
continue; // continue loop of n_batch
|
||||
}
|
||||
for (auto & slot : slots) {
|
||||
if (slot.has_mtp && slot.n_past == slot.n_prompt_tokens) {
|
||||
SLT_INF(slot, "Prompt processing finished. Warming up MTP KV cache for %d tokens.\n", slot.n_prompt_tokens);
|
||||
slot.mtp_kv_update_batch.clear();
|
||||
|
||||
// for (auto & slot : slots) {
|
||||
// // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
|
||||
// if (slot.has_mtp) {
|
||||
// mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
|
||||
// }
|
||||
// }
|
||||
for (int j = 0; j < slot.n_prompt_tokens; ++j) {
|
||||
slot.mtp_kv_update_batch.push_back({ slot.prompt_tokens[j], (llama_pos)j, j });
|
||||
}
|
||||
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
}
|
||||
}
|
||||
|
||||
// move the head of the batch forward with the number of tokens we just processed
|
||||
i_next = i + n_tokens;
|
||||
|
|
@ -3697,8 +3702,9 @@ struct server_context {
|
|||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||
|
||||
if (slot.has_mtp) {
|
||||
slot.mtp_kv_update_batch.clear();
|
||||
for (int32_t i = 0; i < ids.size(); ++i) {
|
||||
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
|
||||
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
|
||||
}
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue