mtp-batch(refactor): Extract decode context and MTP input logic into helper methods
parent 913af8f48d
commit a99709d0c1
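The inline MTP input handling in llama_context::process_ubatch() and the memory-context setup in llama_context::decode() move into two private helpers, prepare_mtp_graph_inputs() and initialize_decode_context(). A minimal sketch of the resulting call sites, condensed from the hunks below (not the complete functions):

    // call site in llama_context::process_ubatch(...)
    if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
        ret = GGML_STATUS_FAILED;
        return nullptr;
    }

    // call site in llama_context::decode(...)
    mctx = this->initialize_decode_context(batch_inp, output_all);
    if (!mctx) {
        return -2;
    }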
@@ -794,28 +794,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     }
 
-    if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
-        const char * target_tensor_name = "result_embd_pooled";
-        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
-
-        const float * source_hidden_state = nullptr;
-        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
-            source_hidden_state = this->embd;
-        } else {
-            source_hidden_state = this->draft_input_hidden_state;
-        }
-
-        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-            const size_t n_embd = this->model.hparams.n_embd;
-            const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-            double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-            const char * op_type = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) ? "MTP_UPDATE" : "DRAFT_GEN";
-
-            LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
-            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
-        } else {
-            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
-                __func__, target_tensor_name);
-        }
-    }
+    if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
+        ret = GGML_STATUS_FAILED;
+        return nullptr;
+    }
 
@@ -1089,27 +1068,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     std::unique_ptr<llama_memory_context_i> mctx;
 
     while (true) {
-        if (cparams.warmup) {
-            mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        } else {
-            if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-                LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
-
-                mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
-                    *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
-                );
-            } else {
-                mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-
-                if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-                    if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
-                        kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
-                    } else {
-                        kvd->last_main_model_sinfos.clear();
-                    }
-                }
-            }
-        }
+        mctx = this->initialize_decode_context(batch_inp, output_all);
 
         if (!mctx) {
            return -2;
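For context on how these paths are selected: decode() keys everything off batch_inp.mtp_params.op_type, so an MTP pass is requested by tagging the batch before llama_decode(). A hypothetical driver sketch; the mtp_params field and the MTP_OP_* values are taken from this commit, but the surrounding code is illustrative only and not part of the change:

    // Hypothetical driver (assumption, not from this commit): run a main-model pass,
    // which records its sinfos, then an MTP update pass over the accepted tokens.
    std::vector<llama_token> tokens; // accepted tokens, filled by the caller
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    batch.mtp_params.op_type = MTP_OP_NONE;            // main-model pass
    if (llama_decode(ctx, batch) != 0) { /* handle error */ }

    batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED; // MTP update pass
    if (llama_decode(ctx, batch) != 0) { /* handle error */ }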
@@ -3149,3 +3108,77 @@ void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     ctx->kv_cache_seq_rm(seq_id, p0, p1);
 }
+
+/*
+    Initializes the memory context for a decode operation.
+    The logic follows a specific priority:
+    1. Warmup: Always use a standard batch initialization.
+    2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it.
+    3. Default: Use a standard batch initialization, and if it's a main model pass,
+       save the resulting s-info for potential future reuse by MTP.
+*/
+std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
+    auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
+    std::unique_ptr<llama_memory_context_i> mctx;
+
+    if (cparams.warmup) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+    } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
+        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
+            *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
+        );
+    } else {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+                kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
+            } else {
+                kvd->last_main_model_sinfos.clear();
+            }
+        }
+    }
+
+    return mctx;
+}
+
+bool llama_context::prepare_mtp_graph_inputs(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch,
+        const llama_mtp_params & mtp_params) {
+
+    const char * target_tensor_name = "result_embd_pooled";
+    ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
+
+    const float * source_hidden_state = nullptr;
+    if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+        source_hidden_state = this->embd;
+    } else { // MTP_OP_DRAFT_GEN
+        source_hidden_state = this->draft_input_hidden_state;
+    }
+
+    if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
+        const size_t n_embd = this->model.hparams.n_embd;
+        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
+        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
+
+        const char * op_type;
+        if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+            op_type = "MTP_UPDATE";
+        } else { // MTP_OP_DRAFT_GEN
+            op_type = "DRAFT_GEN";
+        }
+
+        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
+
+        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
+    } else {
+        LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
+            __func__, target_tensor_name);
+        return false;
+    }
+
+    return true;
+}
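calculate_vector_sum() is used above for the [MTP-INPUT-CHECK] log but is not shown in this diff. A minimal sketch of the assumed behavior, i.e. accumulating the float buffer into a double so the checksum stays stable for large activations:

    // Assumed shape of the checksum helper referenced above; the real definition
    // lives elsewhere in the tree and may differ.
    static double calculate_vector_sum(const float * data, size_t count) {
        double sum = 0.0;
        for (size_t i = 0; i < count; ++i) {
            sum += data[i];
        }
        return sum;
    }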
@@ -231,6 +231,14 @@ private:
     llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
+    // Methods for MTP decode
+    std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);
+
+    bool prepare_mtp_graph_inputs(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch,
+        const llama_mtp_params & mtp_params);
+
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);