mtp-batch(refactor): Extract decode context and MTP input logic into helper methods

samuel 2025-10-10 17:24:34 -03:00
parent 913af8f48d
commit a99709d0c1
2 changed files with 84 additions and 43 deletions

@@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
    }
    if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
        const char * target_tensor_name = "result_embd_pooled";
        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
        const float * source_hidden_state = nullptr;
        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
            source_hidden_state = this->embd;
        } else {
            source_hidden_state = this->draft_input_hidden_state;
        }
        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
            const size_t n_embd = this->model.hparams.n_embd;
            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
            double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
            LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
        } else {
            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
                __func__, target_tensor_name);
        if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
            ret = GGML_STATUS_FAILED;
            return nullptr;
        }
@@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
    std::unique_ptr<llama_memory_context_i> mctx;
    while (true) {
        if (cparams.warmup) {
            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
        } else {
            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
                );
            } else {
                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
                if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
                    } else {
                        kvd->last_main_model_sinfos.clear();
                    }
                }
            }
        }
        mctx = this->initialize_decode_context(batch_inp, output_all);
        if (!mctx) {
            return -2;
@@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    ctx->kv_cache_seq_rm(seq_id, p0, p1);
}

/*
    Initializes the memory context for a decode operation.
    The logic follows a specific priority:
    1. Warmup: Always use a standard batch initialization.
    2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it.
    3. Default: Use a standard batch initialization, and if it's a main model pass,
       save the resulting s-info for potential future reuse by MTP.
*/
std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);

    std::unique_ptr<llama_memory_context_i> mctx;

    if (cparams.warmup) {
        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
    } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
        mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
            *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
        );
    } else {
        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
                kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
            } else {
                kvd->last_main_model_sinfos.clear();
            }
        }
    }

    return mctx;
}
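
The caller side that drives priority 2 above is not part of this diff. Below is a minimal sketch of how an MTP update pass is presumably expected to reuse the slot layout saved by the main-model pass; the member name mtp_update_decode_sketch and the raw-pointer semantics of forced_sinfos are assumptions for illustration, not code from this commit.

// Hypothetical caller sketch (not in this commit): reuse the main model's slot
// layout for an MTP update pass by routing through the forced-sinfo branch.
void llama_context::mtp_update_decode_sketch(const llama_batch & batch_inp, bool output_all) {
    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);

    // A preceding main-model decode (op_type == MTP_OP_NONE) already stored its
    // slot layout in kvd->last_main_model_sinfos via initialize_decode_context().
    kvd->forced_sinfos = &kvd->last_main_model_sinfos; // assumption: forced_sinfos is a plain pointer

    // This call now takes the "forced s-info" branch and bypasses find_slot,
    // so the MTP layers update the same KV cells the main model used.
    auto mctx = initialize_decode_context(batch_inp, output_all);

    // Clear the override so the next regular decode allocates slots normally.
    kvd->forced_sinfos = nullptr;

    (void) mctx; // graph build/compute for the MTP update would follow here
}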
bool llama_context::prepare_mtp_graph_inputs(
        llm_graph_result * res,
        const llama_ubatch & ubatch,
        const llama_mtp_params & mtp_params) {
    const char * target_tensor_name = "result_embd_pooled";
    ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);

    const float * source_hidden_state = nullptr;
    if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
        source_hidden_state = this->embd;
    } else { // MTP_OP_DRAFT_GEN
        source_hidden_state = this->draft_input_hidden_state;
    }

    if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
        const size_t n_embd = this->model.hparams.n_embd;
        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);

        const char * op_type;
        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
            op_type = "MTP_UPDATE";
        } else { // MTP_OP_DRAFT_GEN
            op_type = "DRAFT_GEN";
        }
        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);

        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
    } else {
        LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
            __func__, target_tensor_name);
        return false;
    }

    return true;
}
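
calculate_vector_sum, used for the [MTP-INPUT-CHECK] log above, is also not defined in this diff. A minimal sketch of what such a checksum helper presumably looks like follows, assuming it simply accumulates the floats into a double; the real definition elsewhere in the tree may differ.

// Assumed shape of the checksum helper referenced above (not part of this commit):
// sums `count` floats into a double so the [MTP-INPUT-CHECK] log stays comparable.
static double calculate_vector_sum(const float * data, size_t count) {
    double sum = 0.0;
    for (size_t i = 0; i < count; ++i) {
        sum += data[i];
    }
    return sum;
}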

@@ -231,6 +231,14 @@ private:
    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;

    // Methods for MTP decode
    std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);
    bool prepare_mtp_graph_inputs(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_mtp_params & mtp_params);

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i & io);