mtp-batch (wip): merge mtp and model graph

samuel 2025-09-21 21:29:00 -03:00
parent 1318b2de82
commit 042eb8a829
6 changed files with 70 additions and 130 deletions
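In short: the separate MTP graph path (mtp_graph_params, build_mtp_graph, llm_build_glm4_moe_mtp) is removed, and a single do_mtp_kv_update / update_mtp_kv flag is threaded from decode() through process_ubatch() and graph_params() into llm_build_glm4_moe, which now builds either the regular lm_head or the fused MTP tail on top of the shared trunk. Below is a minimal, self-contained sketch of that flag-threading pattern; the types and function names are toy stand-ins, not the real llama.cpp API, and only the flag itself mirrors the diff.

#include <cstdio>
#include <memory>

// Toy stand-ins, only to illustrate how the commit threads one flag from the
// caller down to a single graph builder instead of building a second graph.
struct toy_graph_params {
    bool update_mtp_kv;   // counterpart of the new llm_graph_params::update_mtp_kv field
};

struct toy_glm4_moe_builder {
    // one builder, two tails: the flag picks lm_head or the fused MTP tail
    toy_glm4_moe_builder(const toy_graph_params & params, bool build_mtp_path) {
        (void) params;
        if (build_mtp_path) {
            std::printf("build: shared trunk + MTP tail (also updates the MTP KV state)\n");
        } else {
            std::printf("build: shared trunk + lm_head\n");
        }
    }
};

// counterpart of llama_model::build_graph(): the flag rides inside the params
static void toy_build_graph(const toy_graph_params & params) {
    auto llm = std::make_unique<toy_glm4_moe_builder>(params, params.update_mtp_kv);
}

// counterpart of llama_context::process_ubatch(): the caller's flag becomes a graph parameter
static void toy_process_ubatch(bool do_mtp_kv_update) {
    const toy_graph_params gparams = { /*.update_mtp_kv =*/ do_mtp_kv_update };
    toy_build_graph(gparams);
}

int main() {
    toy_process_ubatch(false);  // ordinary decode step
    toy_process_ubatch(true);   // decode step that also refreshes the MTP branch
    return 0;
}

Folding both paths into one builder also makes the flag part of the graph parameters, which is what the reuse check (res->can_reuse(gparams)) keys on, and it removes the extra scheduler and graph allocation that the old MTP path needed.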

@@ -729,7 +729,8 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
bool do_mtp_kv_update) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
@@ -741,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
// the new graph parameters
// in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
const auto gparams = graph_params(res, ubatch, mctx, gtype);
const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
if (!graph_reuse_disable && res->can_reuse(gparams)) {
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
@@ -781,7 +782,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
}
const int64_t t_exec_start_us = ggml_time_us();
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
const int64_t t_exec_end_us = ggml_time_us();
LLAMA_LOG_INFO(
"[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
(t_exec_end_us - t_exec_start_us) / 1000.0,
ubatch.n_tokens,
do_mtp_kv_update ? "yes" : "no"
);
if (status != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
ret = status;
@@ -850,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
cparams.causal_attn = false;
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false);
cparams.causal_attn = causal_attn_org;
@@ -1092,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1130,39 +1139,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
if (do_mtp_kv_update) {
LLAMA_LOG_INFO(
"[MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n",
ubatch.n_tokens
);
auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
ggml_backend_sched_t sched_mtp = params_mtp.sched;
auto * gf_mtp = model.build_mtp_graph(params_mtp);
if (gf_mtp) {
ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
ggml_tensor* prev_embedding_tensor = res->get_embd();
ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
// ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
if (ubatch.output[0]) {
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
if (logits_mtp) {
float * logits_dest = logits + n_outputs_prev * n_vocab;
ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
}
}
}
ggml_backend_sched_free(sched_mtp);
}
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
embd_tensor = res->get_embd();
@@ -1442,7 +1418,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false);
res->reset();
@@ -1462,8 +1438,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llm_graph_params llama_context::graph_params(
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const {
const llama_memory_context_i * mctx,
llm_graph_type gtype,
bool update_mtp_kv) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
@@ -1476,36 +1453,13 @@ llm_graph_params llama_context::graph_params(
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.update_mtp_kv =*/ update_mtp_kv,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}
llm_graph_params llama_context::mtp_graph_params(
llm_graph_result * res,
const llama_ubatch& ubatch,
const llama_memory_context_i * mctx) {
size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
/*.sched =*/ temp_sched,
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ 1,
/*.cb =*/ graph_get_cb(temp_sched),
/*.res =*/ res,
};
}
std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
const auto& vocab = model.vocab;
const auto& hparams = model.hparams;
@@ -2240,7 +2194,7 @@ void llama_context::opt_epoch_iter(
auto * res = gf_res_prev.get();
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false);
res->reset();

@@ -99,7 +99,8 @@ struct llama_context {
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
ggml_status & ret);
ggml_status & ret,
const bool do_mtp_kv_update);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
@@ -200,8 +201,6 @@ public:
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
@@ -213,7 +212,8 @@ private:
llm_graph_result * res,
const llama_ubatch & ubatch,
const llama_memory_context_i * mctx,
llm_graph_type gtype) const;
llm_graph_type gtype,
bool update_mtp_kv) const;
llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;

@@ -402,6 +402,7 @@ struct llm_graph_params {
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
bool update_mtp_kv;
uint32_t n_outputs;

@@ -13787,7 +13787,8 @@ struct llm_build_glm4 : public llm_graph_context {
};
struct llm_build_glm4_moe : public llm_graph_context {
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path)
: llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13932,68 +13933,57 @@ struct llm_build_glm4_moe : public llm_graph_context {
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
// cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
if (build_mtp_path) {
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head);
res->t_logits = mtp_logits;
} else {
// lm_head
cur = build_lora_mm(model.output, cur);
res->t_logits = cur;
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(gf, res->t_logits);
}
};
struct llm_build_glm4_moe_mtp : public llm_graph_context {
llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
private:
ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
int64_t n_embd_head
) {
const int il = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il];
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv_unified();
ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
ggml_set_input(prev_embeddings_batch);
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
// now proceed through last layer (skipped in main model)
ggml_tensor * inpSA = cur;
// Pre-attention norm for the MTP block
ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
if (mtp_layer.bq) {
Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
}
if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
if (mtp_layer.bk) {
Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
}
if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
if (mtp_layer.bv) {
Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
}
if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -14025,10 +14015,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
mtp_layer.wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
mtp_layer.wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -14068,9 +14058,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
res->t_logits = cur;
ggml_build_forward_expand(gf, res->t_logits);
return cur;
}
};
@@ -18299,8 +18287,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
const int64_t t_start_us = ggml_time_us();
std::unique_ptr<llm_graph_context> llm;
const bool build_mtp = params.update_mtp_kv;
switch (arch) {
case LLM_ARCH_LLAMA:
{
@@ -18519,7 +18511,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GLM4_MOE:
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
} break;
case LLM_ARCH_BITNET:
{
@@ -18660,22 +18652,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
return llm->res->get_gf();
}
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
std::unique_ptr<llm_graph_context> llm;
switch (arch) {
case LLM_ARCH_GLM4_MOE:
{
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
}
const int64_t t_end_us = ggml_time_us(); // end of timer
LLAMA_LOG_INFO(
"[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
(t_end_us - t_start_us) / 1000.0,
build_mtp ? "yes" : "no"
);
return llm->res->get_gf();
}

@@ -475,7 +475,6 @@ struct llama_model {
// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;
ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const;
private:
struct impl;

@@ -1739,7 +1739,7 @@ struct server_queue {
while (true) {
QUE_DBG("%s", "processing new tasks\n");
const int64_t t_turn_start_us = ggml_time_us();
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
@@ -1762,7 +1762,11 @@ struct server_queue {
QUE_DBG("%s", "update slots\n");
callback_update_slots();
const int64_t t_turn_end_us = ggml_time_us();
SRV_DBG(
"[PERF] Server turn time: %.2f ms\n",
(t_turn_end_us - t_turn_start_us) / 1000.0
);
QUE_DBG("%s", "waiting for new tasks\n");
{
std::unique_lock<std::mutex> lock(mutex_tasks);