From 042eb8a829876ed175320df9c8133bcea0c40460 Mon Sep 17 00:00:00 2001
From: samuel
Date: Sun, 21 Sep 2025 21:29:00 -0300
Subject: [PATCH] mtp-batch (wip): merge mtp and model graph

---
 src/llama-context.cpp   | 84 ++++++++---------------------------
 src/llama-context.h     |  8 ++--
 src/llama-graph.h       |  1 +
 src/llama-model.cpp     | 98 +++++++++++++++++------------------------
 src/llama-model.h       |  1 -
 tools/server/server.cpp |  8 +++-
 6 files changed, 70 insertions(+), 130 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 69549edb1c..754ad6a041 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -729,7 +729,8 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
+                                                 bool do_mtp_kv_update) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -741,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
 
     // the new graph parameters
     // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
-    const auto gparams = graph_params(res, ubatch, mctx, gtype);
+    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
 
     if (!graph_reuse_disable && res->can_reuse(gparams)) {
         //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
@@ -781,7 +782,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
     }
 
+    const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
+    const int64_t t_exec_end_us = ggml_time_us();
+    LLAMA_LOG_INFO(
+        "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
+        (t_exec_end_us - t_exec_start_us) / 1000.0,
+        ubatch.n_tokens,
+        do_mtp_kv_update ? "yes" : "no"
"yes" : "no" + ); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -850,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false); cparams.causal_attn = causal_attn_org; @@ -1092,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1130,39 +1139,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - if (do_mtp_kv_update) { - LLAMA_LOG_INFO( - "[MTP BATCHING] Processando MTP KV update para um ubatch de %u tokens.\n", - ubatch.n_tokens - ); - auto res_mtp = std::make_unique(graph_max_nodes()); - - auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get()); - ggml_backend_sched_t sched_mtp = params_mtp.sched; - - auto * gf_mtp = model.build_mtp_graph(params_mtp); - if (gf_mtp) { - ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp); - - ggml_tensor* prev_embedding_tensor = res->get_embd(); - ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input"); - - // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor)); - ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp); - - ggml_backend_sched_graph_compute(sched_mtp, gf_mtp); - - if (ubatch.output[0]) { - struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - if (logits_mtp) { - float * logits_dest = logits + n_outputs_prev * n_vocab; - ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp)); - } - } - } - ggml_backend_sched_free(sched_mtp); - } - auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
         embd_tensor = res->get_embd();
@@ -1442,7 +1418,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
     auto * res = gf_res_reserve.get();
 
-    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false);
 
     res->reset();
@@ -1462,8 +1438,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 llm_graph_params llama_context::graph_params(
                         llm_graph_result * res,
                   const llama_ubatch & ubatch,
-    const llama_memory_context_i * mctx,
-            llm_graph_type gtype) const {
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype,
+        bool update_mtp_kv) const {
     return {
         /*.arch            =*/ model.arch,
         /*.hparams         =*/ model.hparams,
@@ -1476,36 +1453,13 @@ llm_graph_params llama_context::graph_params(
         /*.loras           =*/ &loras,
         /*.mctx            =*/ mctx,
         /*.cross           =*/ &cross,
+        /*.update_mtp_kv   =*/ update_mtp_kv,
         /*.n_outputs       =*/ n_outputs,
         /*.cb              =*/ graph_get_cb(),
         /*.res             =*/ res,
     };
 }
 
-llm_graph_params llama_context::mtp_graph_params(
-        llm_graph_result * res,
-        const llama_ubatch& ubatch,
-        const llama_memory_context_i * mctx) {
-    size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
-    ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
-    return {
-        /*.arch            =*/ model.arch,
-        /*.hparams         =*/ model.hparams,
-        /*.cparams         =*/ cparams,
-        /*.ubatch          =*/ ubatch,
-        /*.gtype           =*/ LLM_GRAPH_TYPE_DECODER,
-        /*.sched           =*/ temp_sched,
-        /*.backend_cpu     =*/ backend_cpu,
-        /*.cvec            =*/ &cvec,
-        /*.loras           =*/ &loras,
-        /*.mctx            =*/ mctx,
-        /*.cross           =*/ &cross,
-        /*.n_outputs       =*/ 1,
-        /*.cb              =*/ graph_get_cb(temp_sched),
-        /*.res             =*/ res,
-    };
-}
-
 std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
     const auto& vocab = model.vocab;
     const auto& hparams = model.hparams;
@@ -2240,7 +2194,7 @@ void llama_context::opt_epoch_iter(
 
             auto * res = gf_res_prev.get();
 
-            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+            const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false);
 
             res->reset();
 
diff --git a/src/llama-context.h b/src/llama-context.h
index e8ea3a4c9b..88f63e88d7 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -99,7 +99,8 @@ struct llama_context {
             const llama_ubatch & ubatch,
                     llm_graph_type gtype,
             llama_memory_context_i * mctx,
-                     ggml_status & ret);
+                     ggml_status & ret,
+               const bool do_mtp_kv_update);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -200,8 +201,6 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
-
     void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
 
     ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
@@ -213,7 +212,8 @@ private:
             llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
-            llm_graph_type gtype) const;
+            llm_graph_type gtype,
+            bool update_mtp_kv) const;
 
     llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 57772d9c15..3f8fe8e979 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -402,6 +402,7 @@ struct llm_graph_params {
     const llama_adapter_loras    * loras;
     const llama_memory_context_i * mctx;
    const llama_cross             * cross;
+    bool update_mtp_kv;
 
     uint32_t n_outputs;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index cce99ef3b1..c499870710 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13787,7 +13787,8 @@ struct llm_build_glm4 : public llm_graph_context {
 };
 
 struct llm_build_glm4_moe : public llm_graph_context {
-    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path)
+        : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13932,68 +13933,57 @@ struct llm_build_glm4_moe : public llm_graph_context {
         cur = inpL;
         cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
-        cb(cur, "result_norm", -1);
+        // cb(cur, "result_norm", -1);
         res->t_embd = cur;
 
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
+        if (build_mtp_path) {
+            const int il_mtp = hparams.n_layer - 1;
+            const auto & mtp_layer = model.layers[il_mtp];
+
+            ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head);
+            res->t_logits = mtp_logits;
+        } else {
+            // lm_head
+            cur = build_lora_mm(model.output, cur);
+            res->t_logits = cur;
+        }
 
-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, cur);
+        ggml_build_forward_expand(gf, res->t_logits);
     }
-};
-
-struct llm_build_glm4_moe_mtp : public llm_graph_context {
-    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+private:
+    ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
+        int64_t n_embd_head
+    ) {
         const int il = hparams.n_layer - 1;
-        const auto & mtp_layer = model.layers[il];
 
         ggml_tensor * inp_pos = build_inp_pos();
         auto * inp_attn = build_attn_inp_kv_unified();
-
-        ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
-        ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
-        ggml_set_input(prev_embeddings_batch);
-
         ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
-        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
-
+        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
-
         ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
-        // Pre-attention norm for the MTP block
-        ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
 
         // self-attention
         {
             ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
-            if (mtp_layer.bq) {
-                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
-            }
+            if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
             cb(Qcur, "Qcur", il);
 
             ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
-            if (mtp_layer.bk) {
-                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
-            }
+            if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
             cb(Kcur, "Kcur", il);
 
             ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
-            if (mtp_layer.bv) {
-                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
-            }
+            if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -14025,10 +14015,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
-            
+
             cur = build_attn(inp_attn,
-                    mtp_layer.wo, NULL,
-                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                mtp_layer.wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
 
         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -14068,9 +14058,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
 
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, res->t_logits);
+        return cur;
     }
 };
 
@@ -18299,8 +18287,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 }
 
 ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
+    const int64_t t_start_us = ggml_time_us();
+
     std::unique_ptr<llm_graph_context> llm;
 
+    const bool build_mtp = params.update_mtp_kv;
+
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
@@ -18519,7 +18511,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique<llm_build_glm4_moe>(*this, params);
+                llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
             } break;
         case LLM_ARCH_BITNET:
             {
@@ -18660,22 +18652,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
-
-    return llm->res->get_gf();
-}
-
-ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
-    std::unique_ptr<llm_graph_context> llm;
-
-    switch (arch) {
-        case LLM_ARCH_GLM4_MOE:
-            {
-                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
+    const int64_t t_end_us = ggml_time_us(); // end of graph build timer
+    LLAMA_LOG_INFO(
+        "[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
+        (t_end_us - t_start_us) / 1000.0,
+        build_mtp ? "yes" : "no"
"yes" : "no" + ); return llm->res->get_gf(); } diff --git a/src/llama-model.h b/src/llama-model.h index f5f9452a5b..6fcd74d57f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,7 +475,6 @@ struct llama_model { // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; - ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const; private: struct impl; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 34053cd040..84a0e6fc15 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1739,7 +1739,7 @@ struct server_queue { while (true) { QUE_DBG("%s", "processing new tasks\n"); - + const int64_t t_turn_start_us = ggml_time_us(); while (true) { std::unique_lock lock(mutex_tasks); if (!running) { @@ -1762,7 +1762,11 @@ struct server_queue { QUE_DBG("%s", "update slots\n"); callback_update_slots(); - + const int64_t t_turn_end_us = ggml_time_us(); + SRV_DBG( + "[PERF] Server turn time: %.2f ms\n", + (t_turn_end_us - t_turn_start_us) / 1000.0 + ); QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks);