From 1318b2de82716710b9853e07bd640443a5a025bb Mon Sep 17 00:00:00 2001
From: samuel
Date: Sun, 14 Sep 2025 10:22:59 -0300
Subject: [PATCH] mtp-batch (wip): move mtp execution to batch format

---
 common/speculative.cpp |  47 +++++++------
 include/llama.h        |   5 +-
 src/llama-batch.cpp    |  15 ++--
 src/llama-context.cpp  | 152 +++++++++++++++++++++++++----------------
 src/llama-graph.cpp    |  20 ++++++
 src/llama-graph.h      |   1 +
 src/llama-model.cpp    |  52 ++++----------
 src/llama-model.h      |   3 +-
 8 files changed, 166 insertions(+), 129 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 77ed75913d..d13666c9f9 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -374,47 +374,54 @@ llama_token mtp_speculative_gen_draft(
         return -1;
     }
 
-    llama_batch batch = llama_batch_init(1, 0, 1);
-    common_batch_add(batch, id_last, n_past, {0}, true);
+    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
+    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
+    mtp_batch.update_mtp_kv = true;
 
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    llama_decode(ctx, mtp_batch);
+    llama_batch_free(mtp_batch);
 
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
     const int n_vocab = llama_n_vocab(vocab);
-
    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
-
    cur_p->size = n_vocab;
 
    for (int i = 0; i < n_vocab; ++i) {
        cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check that position 0 is the right index
    }
    cur_p->sorted = false;
-
    common_sampler_apply_chain(smpl, cur_p);
-
-    const llama_token id = cur_p->data[0].id;
-
-    llama_batch_free(batch);
-
-    return id;
+
+    return cur_p->data[0].id;
 }
 
 void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
-    mtp_kv_update_data token;
-
+    if (tokens.empty()) {
+        tokens.clear();
+        return;
+    }
     if (n_tokens < 0) {
         n_tokens = tokens.size();
     }
+    const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens);
 
-    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
-        token = tokens[i];
-        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
-
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
+    LOG_DBG(
+        "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
+        n_to_process
+    );
+    llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
+
+    for (size_t i = 0; i < n_to_process; ++i) {
+        const mtp_kv_update_data& token_data = tokens[i];
+        common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
     }
+    mtp_batch.update_mtp_kv = true;
+
+    llama_decode(ctx, mtp_batch);
+
+    llama_batch_free(mtp_batch);
 
     tokens.clear();
 }
\ No newline at end of file
diff --git a/include/llama.h b/include/llama.h
index e43cd83468..0916bb9c5f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -230,6 +230,7 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
+        bool update_mtp_kv;
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -1454,8 +1455,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    // LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+    //     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index ff73429301..589b138531 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one(
 
 struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = {
-        /*n_tokens =*/ 0,
-        /*tokens   =*/ nullptr,
-        /*embd     =*/ nullptr,
-        /*pos      =*/ nullptr,
-        /*n_seq_id =*/ nullptr,
-        /*seq_id   =*/ nullptr,
-        /*logits   =*/ nullptr,
+        /*n_tokens      =*/ 0,
+        /*tokens        =*/ nullptr,
+        /*embd          =*/ nullptr,
+        /*pos           =*/ nullptr,
+        /*n_seq_id      =*/ nullptr,
+        /*seq_id        =*/ nullptr,
+        /*logits        =*/ nullptr,
+        /*update_mtp_kv =*/ false,
     };
 
     if (embd) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fb285a8d29..69549edb1c 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1070,6 +1070,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     };
 
     int64_t n_outputs_prev = 0;
+    const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
 
     do {
         const auto & ubatch = mctx->get_ubatch();
@@ -1129,6 +1130,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
+        if (do_mtp_kv_update) {
+            LLAMA_LOG_INFO(
+                "[MTP BATCHING] Processing MTP KV update for a ubatch of %u tokens.\n",
+                ubatch.n_tokens
+            );
+            auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
+
+            auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
+            ggml_backend_sched_t sched_mtp = params_mtp.sched;
+
+            auto * gf_mtp = model.build_mtp_graph(params_mtp);
+            if (gf_mtp) {
+                ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
+
+                ggml_tensor* prev_embedding_tensor = res->get_embd();
+                ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
+
+                // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
+                ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
+
+                ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
+
+                if (ubatch.output[0]) {
+                    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+                    if (logits_mtp) {
+                        float * logits_dest = logits + n_outputs_prev * n_vocab;
+                        ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
+                    }
+                }
+            }
+            ggml_backend_sched_free(sched_mtp);
+        }
+
         auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
         embd_tensor = res->get_embd();
@@ -2995,79 +3029,79 @@ void llama_opt_epoch(
         callback_eval);
 }
 
-void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-    const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+// void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+//     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
-    const auto * model = llama_get_model(ctx);
+//     const auto * model = llama_get_model(ctx);
 
-    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
-    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
+//     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
+//     std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
 
-    std::vector<uint32_t> idxs;
-    idxs.push_back(n_past);
-    llama_kv_cache_unified::slot_info sinfo = {
-        /*.s0   =*/ 0,
-        /*.s1   =*/ 0,
-        /*.strm =*/ { 0 },
-        /*.idxs =*/ { idxs },
-    };
-    llama_kv_cache_unified::slot_info_vec_t sinfos;
-    sinfos.push_back(sinfo);
+//     std::vector<uint32_t> idxs;
+//     idxs.push_back(n_past);
+//     llama_kv_cache_unified::slot_info sinfo = {
+//         /*.s0   =*/ 0,
+//         /*.s1   =*/ 0,
+//         /*.strm =*/ { 0 },
+//         /*.idxs =*/ { idxs },
+//     };
+//     llama_kv_cache_unified::slot_info_vec_t sinfos;
+//     sinfos.push_back(sinfo);
 
-    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_sinfos(sinfos);
-    const auto& ubatch_mtp = mctx->get_ubatch();
+//     static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_sinfos(sinfos);
+//     const auto& ubatch_mtp = mctx->get_ubatch();
 
-    //llama_ubatch ubatch_mtp;
-    //ubatch_mtp.n_tokens = 1;
-    //ubatch_mtp.pos = &n_past;
+//     //llama_ubatch ubatch_mtp;
+//     //ubatch_mtp.n_tokens = 1;
+//     //ubatch_mtp.pos = &n_past;
 
-    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
-    ggml_backend_sched_t sched = params_mtp->sched;
+//     auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
+//     ggml_backend_sched_t sched = params_mtp->sched;
 
-    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
+//     auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
 
-    //if (mctx && !mctx->set_n_kv()) {
-    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
-    //}
-    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_n_kv();
+//     //if (mctx && !mctx->set_n_kv()) {
+//     //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+//     //}
+//     static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_n_kv();
 
-    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
+//     auto * gf = model->build_mtp_graph(*params_mtp);
 
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
-        if (sched) ggml_backend_sched_free(sched);
-        return;
-    }
+//     if (!gf) {
+//         LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
+//         if (sched) ggml_backend_sched_free(sched);
+//         return;
+//     }
 
-    ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
-    ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
+//     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+//     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
-    ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
-    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+//     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
+//     ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
 
-    ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
-    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
+//     ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
+//     ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
 
-    ggml_backend_sched_graph_compute(sched, gf); // execute the graph
+//     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+//     struct ggml_tensor * logits_mtp = res_mtp->get_logits();
 
-    if (logits_mtp) {
-        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
-        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
-        if (backend_res) {
-            // ggml_backend_tensor_get is the function for GPU->CPU copies.
-            // We are copying a single 32-bit integer.
-            ggml_backend_tensor_get(logits_mtp,
-                logits_dest, // Pointer to our C++ variable
-                0, // Starting offset in bytes
-                ggml_nbytes(logits_mtp)); // Number of bytes to copy
-        } else {
-            LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
-        }
-    } else {
-        LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
-    }
+//     if (logits_mtp) {
+//         float * logits_dest = ctx->get_logits_ith(last_tok_idx);
+//         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
+//         if (backend_res) {
+//             // ggml_backend_tensor_get is the function for GPU->CPU copies.
+//             // We are copying a single 32-bit integer.
+//             ggml_backend_tensor_get(logits_mtp,
+//                 logits_dest, // Pointer to our C++ variable
+//                 0, // Starting offset in bytes
+//                 ggml_nbytes(logits_mtp)); // Number of bytes to copy
+//         } else {
+//             LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
+//         }
+//     } else {
+//         LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
+//     }
 
-    ggml_backend_sched_free(sched);
-}
\ No newline at end of file
+//     ggml_backend_sched_free(sched);
+// }
\ No newline at end of file
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 053c72d6dc..be7de40454 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     return cur;
 }
 
+
+ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const {
+    auto inp = std::make_unique<llm_graph_input_embd>();
+    ggml_tensor * cur = nullptr;
+
+    if (ubatch.token) {
+        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+        ggml_set_name(inp->tokens, "mtp_inp_tokens");
+        ggml_set_input(inp->tokens);
+
+        cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens);
+    } else {
+        GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings");
+    }
+
+    cb(cur, "mtp_inp_embd", -1);
+    res->add_input(std::move(inp));
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_inp_pos() const {
     auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 10702ed219..57772d9c15 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -664,6 +664,7 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
+    ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const;
     ggml_tensor * build_inp_pos() const;
     ggml_tensor * build_inp_attn_scale() const;
     ggml_tensor * build_inp_out_ids() const;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index dd4bf211b7..cce99ef3b1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13946,54 +13946,29 @@ struct llm_build_glm4_moe : public llm_graph_context {
 };
 
 struct llm_build_glm4_moe_mtp : public llm_graph_context {
-    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
-        // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
-        llama_token last_token_id, int n_past
-    ) : llm_graph_context(params) {
+    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-        // Assuming a single MTP layer at the end
         const int il = hparams.n_layer - 1;
         const auto & mtp_layer = model.layers[il];
 
-        // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
-        // ggml_set_i32(inp_pos, n_past);
         ggml_tensor * inp_pos = build_inp_pos();
-
-        //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        // get MTP embedding for last (conventionally sampled) token
-        // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
-        // LLAMA_LOG_INFO("step: '%d'\n", 5641);
-        // ggml_set_i32(inp_token_id, last_token_id);
-        //ggml_set_no_alloc(ctx0, false);
-        //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id);
-
-        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
-        ggml_set_name(inp_token_id, "mtp_token_id_input");
-        ggml_set_input(inp_token_id);
-
-        //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
-        //ggml_set_no_alloc(ctx0, true);
+        ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
+        ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
+        ggml_set_input(prev_embeddings_batch);
 
-        ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
+        ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
+        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
+        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
 
-        ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
-        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
-        ggml_set_input(prev_embedding_leaf);
+        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
 
-        // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
-        ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
-        //token_emb_norm = ggml_cont(ctx0, token_emb_norm);
-        //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm);
-
-        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
-
-        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
+        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
@@ -14090,11 +14065,11 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
             cb(cur, "ffn_out", il);
         }
 
         cur = ggml_add(ctx0, cur, ffn_inp);
-
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
-
+        res->t_logits = cur;
+        ggml_build_forward_expand(gf, res->t_logits);
 
     }
 };
 
@@ -18689,14 +18664,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     return llm->res->get_gf();
 }
 
-ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
-    llama_token last_token_id, int n_past) const {
+ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
+                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
             } break;
         default:
             GGML_ABORT("fatal error");
diff --git a/src/llama-model.h b/src/llama-model.h
index b28a37488f..f5f9452a5b 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -475,8 +475,7 @@ struct llama_model {
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
 
-    ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
-        llama_token last_token_id, int n_past) const;
+    ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const;
 
 private:
     struct impl;