From 5e1d719beffccf8c22784c24b52ff6f5ab56b9ff Mon Sep 17 00:00:00 2001
From: samuel
Date: Thu, 9 Oct 2025 15:21:23 -0300
Subject: [PATCH] mtp-batch (feat): Create and manage sinfo for MTP

---
 common/speculative.cpp         |  27 +++++-
 common/speculative.h           |   7 ++
 include/llama.h                |   6 +-
 src/llama-context.cpp          |  69 ++++++++++++++--
 src/llama-context.h            |  10 +++
 src/llama-kv-cache-unified.cpp | 147 ++++++++++++++++++++++-----------
 src/llama-kv-cache-unified.h   |  15 +++-
 tools/server/server.cpp        |  24 +++---
 8 files changed, 236 insertions(+), 69 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 2e0b91a4e2..f71982f9e4 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -418,10 +418,35 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
         mtp_batch.logits[i] = false;
     }
-
     llama_decode(ctx, mtp_batch);
 }
 
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+) {
+    if (ids.empty()) {
+        return;
+    }
+
+    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
+        return;
+    }
+
+    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
+    for (size_t i = 0; i < ids.size(); ++i) {
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false);
+    }
+
+    mtp_update_kv_cache(ctx, accepted_batch, false);
+
+    llama_mtp_cancel_sinfo_update(ctx);
+
+    llama_batch_free(accepted_batch);
+}
+
 // Debug function - It will be removed later
 double calculate_vector_sum_double(const float* vec, size_t size) {
     if (!vec) {
diff --git a/common/speculative.h b/common/speculative.h
index e121e8ed14..d361e69d07 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -51,4 +51,11 @@ llama_tokens common_speculative_gen_draft(
 
 void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
 
+void mtp_accept_tokens(
+    struct llama_context * ctx,
+    const std::vector<llama_token> & ids,
+    int32_t n_past_base,
+    llama_seq_id seq_id
+);
+
 double calculate_vector_sum_double(const float* vec, size_t size);
\ No newline at end of file
diff --git a/include/llama.h b/include/llama.h
index e6d6aadf7e..024d53f21c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1457,7 +1457,11 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
+
+    LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
+
+    LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 81c7a48d0e..edf5d747f1 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -18,6 +18,11 @@
 //
 // llama_context
 //
+struct llama_context_kv_cache_data {
+    llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos;
+    llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
+    const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
+};
 
 llama_context::llama_context(
         const llama_model & model,
@@ -106,6 +111,8 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
+    kv_cache_data = new llama_context_kv_cache_data();
+
     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
         supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows;
@@ -371,6 +378,7 @@ llama_context::llama_context(
 
 llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
+    delete static_cast<llama_context_kv_cache_data *>(kv_cache_data);
 }
 
 void llama_context::synchronize() {
@@ -1017,6 +1025,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
     LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
         batch_inp.update_mtp_kv ? "true" : "false",
         batch_inp.use_mtp_head ? "true" : "false"
@@ -1076,10 +1086,31 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_context_ptr mctx;
+    std::unique_ptr<llama_memory_context_i> mctx;
 
     while (true) {
-        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (cparams.warmup) {
+            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        } else {
+            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+
+                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+                );
+            } else {
+                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
+                if (!batch_inp.use_mtp_head && !batch_inp.update_mtp_kv) {
+                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+                    } else {
+                        kvd->last_main_model_sinfos.clear();
+                    }
+                }
+            }
+        }
+
         if (!mctx) {
             return -2;
         }
@@ -1091,29 +1122,28 @@ int llama_context::decode(const llama_batch & batch_inp) {
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
                     LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
-
                     return -2;
                 }
            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                {
+                    // if (use_last_main_model_sinfos) {
+                    //     LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__);
+                    //     return -1;
+                    // }
+
                    if (!did_optimize) {
                        did_optimize = true;
-
                        if (kv_self_update(true)) {
                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
-
                            continue;
                        }
                    }
-
                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
-
                    return 1;
                }
            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                {
                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
-
                    return -2;
                }
        }
@@ -3073,4 +3103,27 @@ void llama_opt_epoch(
 
 void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
     ctx->draft_input_hidden_state = hidden_state;
+}
+
+bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
+    const auto & last_sinfo = kvd->last_main_model_sinfos;
+
+    if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) {
+        LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__);
+        return false;
+    }
+
+    kvd->resized_sinfo_for_force = last_sinfo;
+
+    kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted);
+
+    kvd->forced_sinfos = &kvd->resized_sinfo_for_force;
+
+    return true;
+}
+
+void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
+    kvd->forced_sinfos = nullptr;
 }
\ No newline at end of file
diff --git a/src/llama-context.h b/src/llama-context.h
index aa6ced7947..654409cb6c 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -20,6 +20,8 @@ class llama_io_write_i;
 struct llama_memory_i;
 struct llama_memory_context_i;
 
+struct llama_context_kv_cache_data;
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -27,6 +29,11 @@ struct llama_context {
             llama_context_params params);
 
     ~llama_context();
+
+    llama_context(const llama_context &) = delete;
+    llama_context & operator=(const llama_context &) = delete;
+    llama_context(llama_context &&) = delete;
+    llama_context & operator=(llama_context &&) = delete;
 
     void synchronize();
 
@@ -211,6 +218,9 @@ public:
 
     std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
 
+    // For MTP KV cache cell reuse
+    void * kv_cache_data;
+
 private:
     llm_graph_params graph_params(
             llm_graph_result * res,
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 53466264cd..787fb8d9a5 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -508,6 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
     return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos(
+        llama_batch_allocr & balloc,
+        uint32_t n_ubatch,
+        const slot_info_vec_t & sinfos,
+        bool is_inplace_update) {
+
+    if (sinfos.empty()) {
+        return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    balloc.split_reset();
+    std::vector<llama_ubatch> ubatches;
+    while (true) {
+        auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
+        if (ubatch.n_tokens == 0) {
+            break;
+        }
+        ubatches.push_back(std::move(ubatch));
+    }
+
+    if (ubatches.size() != sinfos.size()) {
+        return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_unified_context>(
+        this, sinfos, std::move(ubatches), is_inplace_update);
+}
+
 llama_memory_context_ptr llama_kv_cache_unified::init_full() {
     return std::make_unique<llama_kv_cache_unified_context>(this);
 }
@@ -738,6 +766,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
+    LLAMA_LOG_WARN("%s: Entering find_slot for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
     if (debug > 0) {
         const auto & cells = v_cells[seq_to_stream[1]];
 
@@ -928,73 +957,96 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
     }
 
     assert(res.s1 >= res.s0);
+
+    if (!res.empty()) {
+        std::string idxs_str;
+        for (const auto& vec : res.idxs) {
+            if (!vec.empty()) {
+                if (vec.size() > 8) {
+                    idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]";
+                } else {
+                    idxs_str += " [";
+                    for(size_t i = 0; i < vec.size(); ++i) {
+                        idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", ");
+                    }
+                    idxs_str += "]";
+                }
+            }
+        }
+        LLAMA_LOG_WARN("%s: find_slot SUCCEEDED for ubatch of %d tokens. Idxs:%s\n", __func__, ubatch.n_tokens, idxs_str.c_str());
+    } else {
+        LLAMA_LOG_ERROR("%s: find_slot FAILED to allocate cells for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
+    }
 
     return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
-    // keep track of the max sequence position that we would overwrite with this ubatch
-    // for non-SWA cache, this would be always empty
-    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        seq_pos_max_rm[s] = -1;
-    }
+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
+    if (!is_inplace_update) {
+        // keep track of the max sequence position that we would overwrite with this ubatch
+        // for non-SWA cache, this would be always empty
+        llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            seq_pos_max_rm[s] = -1;
+        }
 
-    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
+        assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
 
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
-            const uint32_t i = s*sinfo.size() + ii;
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+                const uint32_t i = s*sinfo.size() + ii;
 
-            auto & cells = v_cells[sinfo.strm[s]];
+                auto & cells = v_cells[sinfo.strm[s]];
 
-            const auto idx = sinfo.idxs[s][ii];
+                const auto idx = sinfo.idxs[s][ii];
 
-            if (!cells.is_empty(idx)) {
-                assert(cells.seq_count(idx) == 1);
+                if (!is_inplace_update) {
+                    if (!cells.is_empty(idx)) {
+                        assert(cells.seq_count(idx) == 1);
 
-                const llama_seq_id seq_id = cells.seq_get(idx);
-                const llama_pos pos = cells.pos_get(idx);
+                        const llama_seq_id seq_id = cells.seq_get(idx);
+                        const llama_pos pos = cells.pos_get(idx);
 
-                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                        seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
 
-                cells.rm(idx);
-            }
+                        cells.rm(idx);
+                    }
+                }
 
-            cells.pos_set(idx, ubatch.pos[i]);
+                cells.pos_set(idx, ubatch.pos[i]);
 
-            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-                cells.seq_add(idx, ubatch.seq_id[i][s]);
+                for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+                    cells.seq_add(idx, ubatch.seq_id[i][s]);
+                }
             }
         }
-    }
 
-    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
-    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
-    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-    for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
-        if (seq_pos_max_rm[s] == -1) {
-            continue;
-        }
-
-        GGML_ASSERT(s < seq_to_stream.size());
-
-        auto & cells = v_cells[seq_to_stream[s]];
-
-        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
-            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
-                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
-
-            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
-        }
-    }
-
-    // move the head at the end of the slot
-    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
-        auto & head = v_heads[sinfo.strm[s]];
-
-        head = sinfo.idxs[s].back() + 1;
-    }
+        // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+        //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+        // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
+            if (seq_pos_max_rm[s] == -1) {
+                continue;
+            }
+
+            GGML_ASSERT(s < seq_to_stream.size());
+
+            auto & cells = v_cells[seq_to_stream[s]];
+
+            if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+                LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                        __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+                seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+            }
+        }
+
+        // move the head at the end of the slot
+        for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
+            auto & head = v_heads[sinfo.strm[s]];
+
+            head = sinfo.idxs[s].back() + 1;
+        }
+    }
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -2290,7 +2342,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_kv_cache_unified::slot_info_vec_t sinfos,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches,
+        bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) {
 }
 
 llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -2315,7 +2368,7 @@ bool llama_kv_cache_unified_context::apply() {
         return true;
     }
 
-    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
+    kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update);
 
     n_kv = kv->get_n_kv();
 
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index c02607c2d0..f64f7faa5c 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -116,6 +116,12 @@ public:
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
+
+    llama_memory_context_ptr init_batch_with_sinfos(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            const slot_info_vec_t & sinfos,
+            bool is_inplace_update);
 
     llama_memory_context_ptr init_full() override;
 
@@ -181,7 +187,7 @@ public:
     slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
 
     // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
-    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
+    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false);
 
     //
     // input API
@@ -321,7 +327,8 @@ public:
     llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             slot_info_vec_t sinfos,
-            std::vector<llama_ubatch> ubatches);
+            std::vector<llama_ubatch> ubatches,
+            bool is_inplace_update = false);
 
     virtual ~llama_kv_cache_unified_context();
 
@@ -365,6 +372,8 @@ public:
 
     void set_sinfos(slot_info_vec_t new_sinfos);
 
+    const slot_info_vec_t & get_sinfos() const { return sinfos; }
+
 private:
     llama_memory_status status;
 
@@ -399,4 +408,6 @@ private:
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // as the cache gets filled, the benefit from this heuristic disappears
     int32_t n_kv;
+
+    bool is_inplace_update = false;
 };
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 3399e16823..844805d0ce 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3522,8 +3522,15 @@ struct server_context {
             }
 
             if (needs_mtp_warmup) {
-                mtp_update_kv_cache(ctx, batch_view, true);
+                if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) {
+                    mtp_update_kv_cache(ctx, batch_view, true);
+
+                    llama_mtp_cancel_sinfo_update(ctx);
+                } else {
+                    LOG_ERR("%s: Failed to prepare the MTP sinfo for warmup.\n", __func__);
+                }
            }
+
            // move the head of the batch forward with the number of tokens we just processed
            i_next = i + n_tokens;
 
@@ -3696,16 +3703,13 @@ struct server_context {
                SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);
 
                if (!ids.empty()) {
-                    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
-
-                    for (size_t i = 0; i < ids.size(); ++i) {
-                        common_batch_add(accepted_batch, ids[i], slot.n_past + i, { slot.id }, false);
-                    }
-
-                    mtp_update_kv_cache(ctx, accepted_batch, false);
-
-                    llama_batch_free(accepted_batch);
+                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
+                } else {
+                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
                }
+
+                mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);
+
                const float* embd_after_update_ptr = llama_get_embeddings(ctx);
                double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
                SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);
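---
Note (not part of the patch): the snippet below is a minimal usage sketch of the prepare/update/cancel sequence introduced above, mirroring the new mtp_accept_tokens helper rather than replacing it. The wrapper name, variable names, and positions are illustrative; only llama_mtp_prepare_sinfo_for_update, mtp_update_kv_cache, llama_mtp_cancel_sinfo_update, llama_batch_init/llama_batch_free, and common_batch_add come from the patched code and existing common helpers.

    // Sketch: run the MTP head over the KV cells already claimed by the last
    // main-model decode, instead of letting find_slot allocate new ones.
    void example_mtp_inplace_update(llama_context * ctx,
                                    const std::vector<llama_token> & accepted,
                                    llama_pos n_past_base,
                                    llama_seq_id seq_id) {
        if (accepted.empty()) {
            return;
        }

        // shrink the sinfo recorded during the last main-model llama_decode()
        // down to the number of accepted tokens; fails if none was recorded
        if (!llama_mtp_prepare_sinfo_for_update(ctx, accepted.size())) {
            return;
        }

        // rebuild a batch holding the accepted tokens at their original positions
        llama_batch batch = llama_batch_init(accepted.size(), 0, 1);
        for (size_t i = 0; i < accepted.size(); ++i) {
            common_batch_add(batch, accepted[i], n_past_base + (llama_pos) i, { seq_id }, false);
        }

        // is_prompt_warmup = false: in-place MTP update over the forced sinfo
        mtp_update_kv_cache(ctx, batch, false);

        // always drop the forced sinfo so the next decode goes through find_slot again
        llama_mtp_cancel_sinfo_update(ctx);

        llama_batch_free(batch);
    }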