From a99709d0c1401d0b447dce1bd0101fb56390f50e Mon Sep 17 00:00:00 2001
From: samuel
Date: Fri, 10 Oct 2025 17:24:34 -0300
Subject: [PATCH] mtp-batch(refactor): Extract decode context and MTP input
 logic into helper methods

---
 src/llama-context.cpp | 119 +++++++++++++++++++++++++++---------------
 src/llama-context.h   |   8 +++
 2 files changed, 84 insertions(+), 43 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f22a398048..4bdbee951d 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     }
 
     if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
-        const char * target_tensor_name = "result_embd_pooled";
-        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
-
-        const float * source_hidden_state = nullptr;
-        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
-            source_hidden_state = this->embd;
-        } else {
-            source_hidden_state = this->draft_input_hidden_state;
-        }
-
-        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-            const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-            double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
-
-            LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
-            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
-        } else {
-            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
-                __func__, target_tensor_name);
+        if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
             ret = GGML_STATUS_FAILED;
             return nullptr;
         }
@@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     std::unique_ptr<llama_memory_context_i> mctx;
 
     while (true) {
-        if (cparams.warmup) {
-            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        } else {
-            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
-
-                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
-                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
-                );
-            } else {
-                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-
-                if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
-                        kvd->last_main_model_sinfos = static_cast<const llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
-                    } else {
-                        kvd->last_main_model_sinfos.clear();
-                    }
-                }
-            }
-        }
+        mctx = this->initialize_decode_context(batch_inp, output_all);
 
         if (!mctx) {
             return -2;
@@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     ctx->kv_cache_seq_rm(seq_id, p0, p1);
 }
+
+/*
+    Initializes the memory context for a decode operation.
+    The logic follows a specific priority:
+    1. Warmup: Always use a standard batch initialization.
+    2. Forced s-info (MTP updates): If a specific KV cache layout is forced, use it.
+    3. Default: Use a standard batch initialization, and if it's a main model pass,
+       save the resulting s-info for potential future reuse by MTP.
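+
+    (An s-info records which KV cache cells find_slot assigned to a ubatch;
+    reusing the main pass's s-infos lets a later MTP update address the same
+    cells instead of allocating fresh ones.)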
+*/
+std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
+    auto * kvd = static_cast<llama_kv_cache_data *>(kv_cache_data);
+    std::unique_ptr<llama_memory_context_i> mctx;
+
+    if (cparams.warmup) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+    } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+            *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+        );
+    } else {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                kvd->last_main_model_sinfos = static_cast<const llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+            } else {
+                kvd->last_main_model_sinfos.clear();
+            }
+        }
+    }
+
+    return mctx;
+}
+
+
+bool llama_context::prepare_mtp_graph_inputs(
+    llm_graph_result * res,
+    const llama_ubatch & ubatch,
+    const llama_mtp_params & mtp_params) {
+
+    const char * target_tensor_name = "result_embd_pooled";
+    ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
+
+    // WARMUP and UPDATE_ACCEPTED consume the hidden states produced by the
+    // main model pass; DRAFT_GEN consumes the state staged for drafting.
+    const float * source_hidden_state = nullptr;
+    if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+        source_hidden_state = this->embd;
+    } else { // MTP_OP_DRAFT_GEN
+        source_hidden_state = this->draft_input_hidden_state;
+    }
+
+    if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
+        const size_t n_embd = this->model.hparams.n_embd;
+        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
+
+        const char * op_type;
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+            op_type = "MTP_UPDATE";
+        } else { // MTP_OP_DRAFT_GEN
+            op_type = "DRAFT_GEN";
+        }
+
+        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
+
+        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
+    } else {
+        LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
+            __func__, target_tensor_name);
+        return false;
+    }
+
+    return true;
+}
diff --git a/src/llama-context.h b/src/llama-context.h
index 70ca4e0832..ab854c1af1 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -231,6 +231,14 @@ private:
 
     llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
+    // Methods for MTP decode
+    std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);
+
+    bool prepare_mtp_graph_inputs(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch,
+        const llama_mtp_params & mtp_params);
+
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);
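
Note for reviewers: calculate_vector_sum, which produces the [MTP-INPUT-CHECK]
checksum above, is not defined in this patch. Below is a minimal sketch of how
the helper is assumed to behave, inferred only from the call site (a float
buffer, an element count, and a double result printed with "%e"); the actual
definition in the tree may differ:

    // Sums n float values into a double accumulator so the debug checksum
    // stays comparable across ubatch sizes and backends.
    static double calculate_vector_sum(const float * data, size_t n) {
        double sum = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum += data[i];
        }
        return sum;
    }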