From df64508b937784112168aa099644b60fef015f05 Mon Sep 17 00:00:00 2001 From: samuel Date: Sun, 21 Sep 2025 21:55:41 -0300 Subject: [PATCH] mtp-batch (wip): merge glm graphs --- common/speculative.cpp | 19 ++- include/llama.h | 6 +- src/llama-batch.cpp | 1 + src/llama-context.cpp | 161 +++++++++---------- src/llama-context.h | 8 +- src/llama-graph.h | 18 +++ src/llama-model.cpp | 348 ++++++++++++++++++++++------------------ tools/server/server.cpp | 21 ++- 8 files changed, 331 insertions(+), 251 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index d13666c9f9..1604dbd48a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -373,10 +373,24 @@ llama_token mtp_speculative_gen_draft( if (!smpl) { return -1; } - + const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx); + llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state); + llama_batch mtp_batch = llama_batch_init(1, 0, 1); common_batch_add(mtp_batch, id_last, n_past, {0}, true); - mtp_batch.update_mtp_kv = true; + + LOG_INF( + "[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n", + id_last, n_past, last_tok_idx + ); + + mtp_batch.update_mtp_kv = false; + mtp_batch.use_mtp_head = true; + + LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n", + mtp_batch.update_mtp_kv ? "true" : "false", + mtp_batch.use_mtp_head ? "true" : "false" + ); llama_decode(ctx, mtp_batch); llama_batch_free(mtp_batch); @@ -419,6 +433,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vectorapply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -742,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update); + const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -773,6 +773,29 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); + + const float * source_hidden_state = nullptr; + if (do_mtp_kv_update) { + // Cache warming uses the entire embeddings buffer + source_hidden_state = this->embd; + } else { + // Draft generation uses the specific state + source_hidden_state = this->draft_input_hidden_state; + } + + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); + } else { + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + ret = GGML_STATUS_FAILED; + return nullptr; + } + } + // set the input data for the input tensors { //const auto t_start_us = ggml_time_us(); @@ -798,7 +821,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - + if (do_mtp_kv_update || use_mtp_head) { + ggml_tensor * sum_tensor = 
ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); + if (sum_tensor) { + LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n"); + } + } return res; } @@ -859,7 +887,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false); cparams.causal_attn = causal_attn_org; @@ -972,6 +1000,10 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n", + batch_inp.update_mtp_kv ? "true" : "false", + batch_inp.use_mtp_head ? "true" : "false" + ); if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); @@ -1080,9 +1112,24 @@ int llama_context::decode(const llama_batch & batch_inp) { int64_t n_outputs_prev = 0; const bool do_mtp_kv_update = batch_inp.update_mtp_kv; - + const bool use_mtp_head = batch_inp.use_mtp_head; + const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0); + do { const auto & ubatch = mctx->get_ubatch(); + if (ubatch.n_tokens > 0) { + std::string pos_str; + for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) { + pos_str += std::to_string(ubatch.pos[i]) + " "; + } + LLAMA_LOG_WARN( + "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Posições: %s...\n", + ubatch.n_tokens, + batch_inp.update_mtp_kv ? "true" : "false", + batch_inp.use_mtp_head ? "true" : "false", + pos_str.c_str() + ); + } // count the outputs in this ubatch { @@ -1101,7 +1148,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1139,6 +1186,17 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} + // if (is_prompt_warmup) { + // auto res_mtp = std::make_unique(graph_max_nodes()); + // ggml_status status_mtp; + + // process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head); + + // if (status_mtp != GGML_STATUS_SUCCESS) { + // LLAMA_LOG_WARN("%s: Failure in MTP heating ubatch\n", __func__); + // } + // } + auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; embd_tensor = res->get_embd(); @@ -1278,7 +1336,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. ggml_backend_sched_reset(sched.get()); } - + if (!do_mtp_kv_update && !use_mtp_head) { + LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. 
ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd); + } return 0; } @@ -1418,7 +1478,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false); res->reset(); @@ -1440,7 +1500,8 @@ llm_graph_params llama_context::graph_params( const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv) const { + bool update_mtp_kv, + bool use_mtp_head) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1454,6 +1515,7 @@ llm_graph_params llama_context::graph_params( /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.update_mtp_kv =*/ update_mtp_kv, + /*.use_mtp_head =*/ use_mtp_head, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2194,7 +2256,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false); res->reset(); @@ -2983,79 +3045,6 @@ void llama_opt_epoch( callback_eval); } -// void llama_build_and_execute_mtp_graph(struct llama_context * ctx, -// const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) { - -// const auto * model = llama_get_model(ctx); - -// auto res_mtp = std::make_unique(ctx->graph_max_nodes()); -// std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp); - -// std::vector idxs; -// idxs.push_back(n_past); -// llama_kv_cache_unified::slot_info sinfo = { -// /*.s0 =*/ 0, -// /*.s1 =*/ 0, -// /*.strm =*/ { 0 }, -// /*.idxs =*/ { idxs }, -// }; -// llama_kv_cache_unified::slot_info_vec_t sinfos; -// sinfos.push_back(sinfo); - -// static_cast(mctx.get())->set_sinfos(sinfos); -// const auto& ubatch_mtp = mctx->get_ubatch(); - -// //llama_ubatch ubatch_mtp; -// //ubatch_mtp.n_tokens = 1; -// //ubatch_mtp.pos = &n_past; - -// auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get())); -// ggml_backend_sched_t sched = params_mtp->sched; - -// auto * last_embd = ctx->get_embeddings_ith(last_tok_idx); - -// //if (mctx && !mctx->set_n_kv()) { -// // LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); -// //} -// static_cast(mctx.get())->set_n_kv(); - -// auto * gf = model->build_mtp_graph(*params_mtp); - -// if (!gf) { -// LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__); -// if (sched) ggml_backend_sched_free(sched); -// return; -// } - -// ggml_backend_sched_reset(sched); // clear the allocation of the previous graph -// ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it - -// ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input"); -// ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors - -// ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input"); -// ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors - -// ggml_backend_sched_graph_compute(sched, gf); // execute the 
graph - -// struct ggml_tensor * logits_mtp = res_mtp->get_logits(); - -// if (logits_mtp) { -// float * logits_dest = ctx->get_logits_ith(last_tok_idx); -// ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp); -// if (backend_res) { -// // ggml_backend_tensor_get is the function for GPU->CPU copies. -// // We are copying a single 32-bit integer. -// ggml_backend_tensor_get(logits_mtp, -// logits_dest, // Pointer to our C++ variable -// 0, // Starting offset in bytes -// ggml_nbytes(logits_mtp)); // Number of bytes to copy -// } else { -// LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__); -// } -// } else { -// LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__); -// } - -// ggml_backend_sched_free(sched); -// } \ No newline at end of file +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} \ No newline at end of file diff --git a/src/llama-context.h b/src/llama-context.h index 88f63e88d7..1df3574c27 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -61,6 +61,8 @@ struct llama_context { float * get_embeddings_seq(llama_seq_id seq_id); ggml_tensor * get_embeddings_tensor(); + const float * draft_input_hidden_state = nullptr; + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -100,7 +102,8 @@ struct llama_context { llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, - const bool do_mtp_kv_update); + const bool do_mtp_kv_update, + const bool use_mtp_head); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -213,7 +216,8 @@ private: const llama_ubatch & ubatch, const llama_memory_context_i * mctx, llm_graph_type gtype, - bool update_mtp_kv) const; + bool update_mtp_kv, + bool use_mtp_head) const; llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; diff --git a/src/llama-graph.h b/src/llama-graph.h index 3f8fe8e979..40dd83f0bc 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -94,6 +95,20 @@ public: using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -403,6 +418,7 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; bool update_mtp_kv; + bool use_mtp_head; uint32_t n_outputs; @@ -451,6 +467,8 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + update_mtp_kv == other.update_mtp_kv && + use_mtp_head == other.use_mtp_head && n_outputs == other.n_outputs; } }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c499870710..82c7be49cb 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13787,168 +13787,204 @@ struct llm_build_glm4 : public llm_graph_context { }; struct llm_build_glm4_moe : public llm_graph_context { - 
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path) - : llm_graph_context(params) { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv_unified(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; - - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - } - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } - - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // Post-attention norm - cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); - - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // Process routed experts using existing MoE infrastructure - 
ggml_tensor * routed_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(routed_out, "ffn_moe_out", il); - - // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shared_out, "ffn_shexp_out", il); - - // Final output: routed_output + shared_output - cur = ggml_add(ctx0, routed_out, shared_out); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; + LLAMA_LOG_WARN( + "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n", + params.use_mtp_head ? "true" : "false", + params.update_mtp_kv ? "true" : "false", + n_tokens + ); + // for (int i = 0; i < n_tokens; ++i) { + // LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]); + // } + if (n_tokens > 0) { + LLAMA_LOG_WARN( + " - ubatch tokens: [ID=%d, Pos=%d] ... [ID=%d, Pos=%d]\n", + ubatch.token[0], ubatch.pos[0], + ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1] + ); } - cur = inpL; - cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + if (params.use_mtp_head) { + ggml_tensor* hidden_states_from_main_model; - // cb(cur, "result_norm", -1); - res->t_embd = cur; + if (params.update_mtp_kv) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - if (build_mtp_path) { - const int il_mtp = hparams.n_layer - 1; - const auto & mtp_layer = model.layers[il_mtp]; - - ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head); - res->t_logits = mtp_logits; + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); } else { - // lm_head - cur = build_lora_mm(model.output, cur); - res->t_logits = cur; + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); + + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } + res->t_embd = hidden_states_from_main_model; + + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); + + } else { + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + // if 
(params.use_mtp_head) { + // LLAMA_LOG_ERROR("[DEBUG-KV-ERROR] MTP path is running the main layer %d!\n", il); + // } else { + // LLAMA_LOG_WARN("[DEBUG-KV] Main Head Path: Accessing layer %d\n", il); + // } + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + 
LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + // cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Use the main model header + res->t_logits = build_lora_mm(model.output, cur); } - ggml_build_forward_expand(gf, res->t_logits); + ggml_build_forward_expand(gf, res->t_logits); + } private: @@ -13956,6 +13992,10 @@ private: int64_t n_embd_head ) { const int il = hparams.n_layer - 1; + // LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il); + ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings); + + ggml_set_name(sum_node, "mtp_input_sum"); ggml_tensor * inp_pos = build_inp_pos(); auto * inp_attn = build_attn_inp_kv_unified(); @@ -14015,7 +14055,11 @@ private: cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - + // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il); + // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]); + // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]); + // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]); + cur = build_attn(inp_attn, mtp_layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -18511,7 +18555,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GLM4_MOE: { - llm = std::make_unique(*this, params, build_mtp); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BITNET: { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 84a0e6fc15..7070a56159 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3387,6 +3387,15 @@ struct server_context { slot.n_prompt_tokens_processed += n_pos; } + const size_t n_to_log = slot.mtp_kv_update_batch.size(); + if (n_to_log > 0) { + SLT_INF(slot, + "DEBUG-KV-REQ Cache Warm-up: Requesting KV update for %zu tokens. Positions: %d ... %d\n", + n_to_log, + slot.mtp_kv_update_batch.front().n_past, + slot.mtp_kv_update_batch.back().n_past + ); + } // add prompt tokens for processing in the current batch while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { // get next token to process @@ -3517,12 +3526,12 @@ struct server_context { continue; // continue loop of n_batch } - for (auto & slot : slots) { - // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation - if (slot.has_mtp) { - mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); - } - } + // for (auto & slot : slots) { + // // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation + // if (slot.has_mtp) { + // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens); + // } + // } // move the head of the batch forward with the number of tokens we just processed i_next = i + n_tokens;
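
Usage note (not part of the diff): the caller-side flow this patch wires up in mtp_speculative_gen_draft() reduces to the sequence below. This is a minimal sketch against this branch's API (llama_set_draft_input_hidden_state plus the update_mtp_kv / use_mtp_head batch flags); the helper name mtp_draft_greedy and the greedy argmax in step 3 are illustrative simplifications of the common_sampler call used in common/speculative.cpp, not part of the patch, and it assumes the one-token MTP decode exposes its logits at output index 0.

#include <algorithm>

#include "common.h" // common_batch_add
#include "llama.h"

static llama_token mtp_draft_greedy(llama_context * ctx,
                                    llama_token     id_last,
                                    llama_pos       n_past,
                                    int32_t         last_tok_idx) {
    // 1) hand the main model's hidden state for the last accepted token to the MTP head
    const float * h = llama_get_embeddings_ith(ctx, last_tok_idx);
    llama_set_draft_input_hidden_state(ctx, h);

    // 2) decode a single-token batch through the MTP head only:
    //    use_mtp_head = true, update_mtp_kv = false
    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, { 0 }, true);
    batch.update_mtp_kv = false;
    batch.use_mtp_head  = true;
    llama_decode(ctx, batch);
    llama_batch_free(batch);

    // 3) greedy pick from the logits of the single output token (index 0 of this batch);
    //    the real code samples through common_sampler instead
    const llama_vocab * vocab   = llama_model_get_vocab(llama_get_model(ctx));
    const int32_t       n_vocab = llama_vocab_n_tokens(vocab);
    const float       * logits  = llama_get_logits_ith(ctx, 0);

    return (llama_token) (std::max_element(logits, logits + n_vocab) - logits);
}

Warming the MTP KV cache after prompt processing goes through the same decode entry point, but with update_mtp_kv = true and the full embeddings buffer (ctx->embd) as the hidden-state source, as handled in llama_context::process_ubatch above.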