diff --git a/common/speculative.cpp b/common/speculative.cpp
index 950a9a54bc..503da98194 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -373,7 +373,7 @@ llama_token mtp_speculative_gen_draft(
     if (!smpl) {
         return -1;
     }
-    const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, -1);
+    const float * draft_input_hidden_state = llama_get_embeddings(ctx);
     llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);

     LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n", (void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state);
@@ -413,17 +413,24 @@ llama_token mtp_speculative_gen_draft(
 }

-void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens) {
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag) {
     if (tokens.empty()) {
         return;
     }

     const size_t n_to_process = tokens.size();
+    std::string details_str;
+    for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
+        details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}";
+    }
+    LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", tag, n_to_process, details_str.c_str());

-    LOG_DBG(
-        "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
-        n_to_process
-    );
+    // LOG_INF("[DEBUG-CHUNK] Warming up MTP model chunk. Batch size: %zu\n", n_to_process);
+    // std::string positions_str;
+    // for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
+    //     positions_str += std::to_string(tokens[i].n_past) + " ";
+    // }
+    // LOG_INF("[DEBUG-CHUNK] MTP warm-up positions: %s...\n", positions_str.c_str());

     llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
     for (size_t i = 0; i < n_to_process; ++i) {
diff --git a/common/speculative.h b/common/speculative.h
index 827600c33d..c60bd97ac3 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -49,4 +49,4 @@ llama_tokens common_speculative_gen_draft(
         const llama_tokens & prompt,
         llama_token id_last);

-void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens);
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, const char* tag);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5427f29eb7..070c1b738f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -778,13 +778,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);

         const float * source_hidden_state = nullptr;
-        if (do_mtp_kv_update) {
-            // Cache warming uses the entire embeddings buffer
-            source_hidden_state = this->embd;
-        } else {
-            // Draft generation uses the specific state
-            source_hidden_state = this->draft_input_hidden_state;
-        }
+        source_hidden_state = this->draft_input_hidden_state;

         if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
             ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
@@ -1149,14 +1143,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
         if (do_mtp_kv_update) {
             LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] MTP KV Update ubatch: n_tokens=%d\n", ubatch.n_tokens);
             std::string positions_str;
-            for (int i = 0; i < ubatch.n_tokens; ++i) {
+            for (int i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
                 positions_str += std::to_string(ubatch.pos[i]) + " ";
             }
-            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s\n", positions_str.c_str());
+            LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
         }
         ggml_status status;
         const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);
-
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
             llama_pos pos_min[LLAMA_MAX_SEQ];
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index ddd7b6afa8..aba5859acf 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3387,14 +3387,8 @@ struct server_context {
                     slot.n_prompt_tokens_processed += n_pos;
                 }

-                const size_t n_to_log = slot.mtp_kv_update_batch.size();
-                if (n_to_log > 0) {
-                    SLT_INF(slot,
-                        "DEBUG-KV-REQ Cache Warm-up: Requesting KV update for %zu tokens. Positions: %d ... %d\n",
-                        n_to_log,
-                        slot.mtp_kv_update_batch.front().n_past,
-                        slot.mtp_kv_update_batch.back().n_past
-                    );
+                if (slot.has_mtp) {
+                    slot.mtp_kv_update_batch.clear();
                 }
                 // add prompt tokens for processing in the current batch
                 while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
@@ -3484,6 +3478,7 @@ struct server_context {
                 batch.seq_id + i,
                 batch.logits + i,
             };
+            LOG_INF("\n[DEBUG-CHUNK] Processing main model chunk. Batch size: %d\n", n_tokens);

             const int ret = llama_decode(ctx, batch_view);
@@ -3525,16 +3520,18 @@ struct server_context {
                    continue; // continue loop of n_batch
                }

+
+            // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
+            // Warm up the MTP cache for the prompt chunks that have just been processed.
+            // This logic must ONLY run during prompt processing.
            for (auto & slot : slots) {
-                if (slot.has_mtp && slot.n_past == slot.n_prompt_tokens) {
-                    SLT_INF(slot, "Prompt processing finished. Warming up MTP KV cache for %d tokens.\n", slot.n_prompt_tokens);
-                    slot.mtp_kv_update_batch.clear();
-
-                    for (int j = 0; j < slot.n_prompt_tokens; ++j) {
-                        slot.mtp_kv_update_batch.push_back({ slot.prompt_tokens[j], (llama_pos)j, j });
-                    }
-
-                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) {
+                    SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n",
+                        slot.mtp_kv_update_batch.size(),
+                        slot.mtp_kv_update_batch.front().n_past,
+                        slot.mtp_kv_update_batch.back().n_past
+                    );
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP");
                }
            }
@@ -3581,11 +3578,6 @@ struct server_context {

            common_sampler_accept(slot.smpl, id, true);

-            // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
-            //if (slot.has_mtp) {
-            //    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
-            //}
-
            slot.n_decoded += 1;

            const int64_t t_current = ggml_time_us();
@@ -3670,11 +3662,6 @@ struct server_context {
                draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
            }

-            //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
-            //llama_tokens draft;
-            //draft.reserve(1);
-            //draft.push_back(draft_id);
-
            // ignore small drafts
            if (slot.params.speculative.n_min > (int)draft.size()) {
                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
@@ -3706,7 +3693,7 @@ struct server_context {
                for (int32_t i = 0; i < ids.size(); ++i) {
                    slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
                }
-                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "GEN_ACCEPTED");
            }

            slot.n_past += ids.size();
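Note on the call pattern above: the MTP cache warm-up now runs per decoded prompt chunk (tagged `"PROMPT_WARMUP"`) and again for accepted draft tokens during generation (tagged `"GEN_ACCEPTED"`). The sketch below illustrates that flow under stated assumptions: the element type name `mtp_kv_token`, its third field, and the helper `queue_prompt_chunk` are illustrative only (the diff shows just that the vector elements expose `.id` and `.n_past`), not code from this patch.

```cpp
// Sketch only: "mtp_kv_token" and "queue_prompt_chunk" are assumed names,
// not part of the patch; llama_token / llama_pos come from llama.h.
#include <cstdint>
#include <vector>

#include "llama.h"

struct mtp_kv_token {
    llama_token id;      // token id to warm into the MTP KV cache
    llama_pos   n_past;  // absolute position of the token in the sequence
    int32_t     tok_idx; // index of the token inside the batch that produced it
};

// Accumulate one decoded prompt chunk into the pending MTP update batch.
static void queue_prompt_chunk(std::vector<mtp_kv_token> & pending,
                               const std::vector<llama_token> & chunk,
                               llama_pos first_pos) {
    for (size_t i = 0; i < chunk.size(); ++i) {
        pending.push_back({ chunk[i], first_pos + (llama_pos) i, (int32_t) i });
    }
}

// After the main model decodes the chunk, the server flushes the batch:
//     mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP");
// During generation, accepted draft tokens take the same path with the
// "GEN_ACCEPTED" tag.
```

The `tag` argument only feeds the new `[MTP-UPDATE|%s]` log line in `mtp_update_kv_cache`, so adding it at each call site is cheap and makes the two update paths distinguishable in the logs.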