From fc36eb7700be1e57377f4beda5d234c5ad851c99 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 19 Feb 2026 20:06:09 +0100
Subject: [PATCH] wip

---
 common/speculative.cpp          | 132 ++++++++++++++++++++++++++++++++
 src/llama-arch.cpp              |  12 +--
 src/llama-context.cpp           |   7 +-
 src/llama-model.cpp             |  12 +--
 src/models/glm4-moe.cpp         |   6 +-
 src/models/glm4.cpp             |   6 +-
 tools/server/server-context.cpp |  11 ++-
 7 files changed, 162 insertions(+), 24 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 1592ec8412..05fb4e0333 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -464,6 +464,127 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
     }
 };
 
+// speculative state for NextN/MTP: drafts with ctx_dft after llama_mtp_start() copies state from ctx_tgt
+struct common_speculative_state_nextn : public common_speculative_state {
+    llama_context * ctx_tgt; // used for copying state from tgt --> dft
+    llama_context * ctx_dft;
+
+    common_sampler * smpl;
+
+    llama_batch batch;
+    llama_tokens prompt_dft; // currently unused by this state
+
+    bool vocab_cmpt = true; // whether retokenization is needed
+    std::unordered_map<std::string, std::string> vocab_map;
+
+    common_speculative_state_nextn(
+            enum common_speculative_type type,
+            llama_context * ctx_tgt,
+            llama_context * ctx_dft)
+        : common_speculative_state(type)
+        , ctx_tgt(ctx_tgt)
+        , ctx_dft(ctx_dft)
+    {
+        batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+
+        {
+            common_params_sampling params;
+            params.no_perf = false;
+            params.top_k = 10;
+            params.samplers = {
+                COMMON_SAMPLER_TYPE_TOP_K,
+            };
+
+            smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+        }
+    }
+
+    ~common_speculative_state_nextn() override {
+        llama_perf_context_print(ctx_dft);
+
+        common_sampler_free(smpl);
+
+        llama_batch_free(batch);
+    }
+
+    void begin(const llama_tokens & prompt) override {
+        GGML_UNUSED(prompt);
+    }
+
+    // draft up to params.n_max tokens that follow id_last and store them in result
+    // result ends up shorter (possibly empty) if a decode fails or the top candidate falls below params.p_min
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & result) override {
+        GGML_UNUSED(prompt_tgt);
+
+        result.clear();
+        result.reserve(params.n_max);
+
+        llama_memory_clear(llama_get_memory(ctx_dft), false);
+        common_sampler_reset(smpl);
+
+        llama_mtp_start(ctx_tgt, ctx_dft); // copy state from main LLM to draft
+
+        // decode first token
+        int n_past = 0;
+        common_batch_clear(batch);
+        common_batch_add(batch, id_last, n_past++, { 0 }, true);
+        if (llama_decode(ctx_dft, batch) != 0) {
+            LOG_ERR("%s: failed to decode the last token, skipping draft\n", __func__);
+            return;
+        }
+        common_sampler_accept(smpl, id_last, true);
+
+        // sample n_draft tokens from the draft model
+        for (int i = 0; i < params.n_max; ++i) {
+            common_batch_clear(batch);
+            common_sampler_sample(smpl, ctx_dft, 0, true);
+
+            const auto * cur_p = common_sampler_get_candidates(smpl, true);
+            if (cur_p == nullptr || cur_p->size == 0) {
+                break; // no candidate to draft from
+            }
+
+            for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+                LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                        k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
+            }
+
+            // add drafted token for each sequence
+            const llama_token id = cur_p->data[0].id;
+
+            common_sampler_accept(smpl, id, true);
+
+            result.push_back(id);
+
+            if (params.n_max <= (int) result.size()) {
+                break;
+            }
+
+            // only collect very high-confidence draft tokens
+            if (cur_p->data[0].p < params.p_min) {
+                break;
+            }
+
+            common_batch_add(batch, id, n_past++, { 0 }, true);
+
+            // evaluate the drafted tokens on the draft model
+            if (llama_decode(ctx_dft, batch) != 0) {
+                LOG_ERR("%s: failed to decode draft token %d, stopping draft\n", __func__, i);
+                break;
+            }
+        }
+    }
+
+    void accept(uint16_t n_accepted) override {
+        // noop
+        GGML_UNUSED(n_accepted);
+    }
+};
+
 // state of self-speculation (simple implementation, not ngram-map)
 struct common_speculative_state_ngram_simple : public common_speculative_state {
     common_ngram_simple_config config;
@@ -855,6 +976,7 @@ common_speculative *
common_speculative_init( { bool has_draft = !params.mparams_dft.path.empty(); bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + bool has_draft_nextn = (params.type == COMMON_SPECULATIVE_TYPE_NEXTN); bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -900,6 +1022,9 @@ common_speculative * common_speculative_init( if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); } + if (has_draft_nextn) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NEXTN, params)); + } } std::vector> impls = {}; @@ -921,6 +1046,13 @@ common_speculative * common_speculative_init( impls.push_back(std::make_unique(config.type)); break; } + case COMMON_SPECULATIVE_TYPE_NEXTN: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft + )); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config); diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 3cb45b6922..688c24de2f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2721,12 +2721,12 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, 
+ {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 99119242be..58a887a790 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -800,6 +800,7 @@ int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) { } // TODO: maybe std::move is better? + LLAMA_LOG_DEBUG("%s: copying MTP state (n_token = %lld, n_embd = %lld)\n", __func__, (long long) cross.n_token, (long long) cross.n_embd); ctx_mtp.cross = cross; return 0; @@ -1595,12 +1596,14 @@ int llama_context::decode(const llama_batch & batch_inp) { break; } - const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0; + const bool update_mtp_state = gtype == LLM_GRAPH_TYPE_DECODER_MTP // this is MTP layer + || (hparams.nextn_predict_layers > 0 && n_outputs_all > 0); // or, this is the main LLM, we need to forward state to MTP layer // set MTP state if needed if (update_mtp_state) { + // printf("\n\nupdate MTP state: gtype = %d, n_outputs_all = %d\n", (int) gtype, n_outputs_all); cross.n_embd = hparams.get_n_embd_mtp(); - cross.n_token = n_outputs; + cross.n_token = n_outputs_all; cross.mtp_embd.resize(cross.n_embd*cross.n_token); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 01da83da14..77926522d7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1788,9 +1788,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters (GLM-OCR) ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, 
hparams.nextn_predict_layers, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - switch (hparams.n_layer) { case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; @@ -1821,9 +1818,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open @@ -5475,10 +5469,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } auto & layer = layers[i]; @@ -5505,7 +5495,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + // NextN/MTP tensors if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); diff --git a/src/models/glm4-moe.cpp b/src/models/glm4-moe.cpp index d46723dfe5..11c4b675c0 100644 --- a/src/models/glm4-moe.cpp +++ b/src/models/glm4-moe.cpp @@ -81,6 +81,9 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_m inpL = 
ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0); cb(inpL, "inp_mtp", il); + inpL = build_lora_mm(mtp_layer.nextn.eh_proj, inpL); + cb(inpL, "inp_mtp_projected", il); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -88,8 +91,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_m ggml_tensor * inp_out_ids = build_inp_out_ids(); { - // input for next layer - bool is_output_layer = (il == n_layer - 1); + bool is_output_layer = true; // TODO: we only have one single nextn layer for now, may need to change in the future inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il); } cur = inpL; diff --git a/src/models/glm4.cpp b/src/models/glm4.cpp index dba2c8a763..d4a1b3fad6 100644 --- a/src/models/glm4.cpp +++ b/src/models/glm4.cpp @@ -81,6 +81,9 @@ llm_build_glm4::llm_build_glm4(const llama_model & m inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0); cb(inpL, "inp_mtp", il); + inpL = build_lora_mm(mtp_layer.nextn.eh_proj, inpL); + cb(inpL, "inp_mtp_projected", il); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -88,8 +91,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & m ggml_tensor * inp_out_ids = build_inp_out_ids(); { - // input for next layer - bool is_output_layer = (il == n_layer - 1); + bool is_output_layer = true; // TODO: we only have one single nextn layer for now, may need to change in the future inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il); } cur = inpL; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8aab0d4c1b..41104ccdd8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -639,7 +639,8 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); - if (params_base.speculative.has_dft()) { + //if (params_base.speculative.has_dft()) { + { SRV_INF("loading draft model '%s'\n", 
params_base.speculative.mparams_dft.path.c_str()); const auto & params_spec = params_base.speculative; @@ -662,6 +663,7 @@ private: params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; + /* auto mparams_dft = common_model_params_to_llama(params_dft); model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); @@ -672,6 +674,13 @@ private: params_base.speculative.model_dft = model_dft.get(); params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + */ + + // FOR TESTING ONLY!!!!!! + params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NEXTN; + params_base.speculative.model_dft = model; + params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + params_base.speculative.cparams_dft.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP; } std::string & mmproj_path = params_base.mmproj.path;