This commit is contained in:
Xuan Son Nguyen 2026-02-19 20:06:09 +01:00
parent c256da1f9f
commit fc36eb7700
7 changed files with 162 additions and 24 deletions

View File

@ -464,6 +464,127 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
}
};
struct common_speculative_state_nextn : public common_speculative_state {
llama_context * ctx_tgt; // used for copying state from tgt --> dft
llama_context * ctx_dft;
common_sampler * smpl;
llama_batch batch;
llama_tokens prompt_dft;
bool vocab_cmpt = true; // whether retokenization is needed
std::unordered_map<std::string, std::string> vocab_map;
common_speculative_state_nextn(
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
{
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
smpl = common_sampler_init(llama_get_model(ctx_dft), params);
}
}
~common_speculative_state_nextn() override {
llama_perf_context_print(ctx_dft);
common_sampler_free(smpl);
llama_batch_free(batch);
}
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
auto * spec = this;
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt_dft = spec->prompt_dft;
auto * mem_dft = llama_get_memory(ctx_dft);
result.clear();
result.reserve(params.n_max);
llama_memory_clear(mem_dft, false);
common_sampler_reset(smpl);
llama_mtp_start(ctx_tgt, ctx_dft); // copy state from main LLM to draft
// decode first token
int n_past = 0;
common_batch_clear(batch);
common_batch_add(batch, id_last, n_past++, { 0 }, true);
llama_decode(ctx_dft, batch);
common_sampler_accept(smpl, id_last, true);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_max; ++i) {
// printf("drafting token %d\n", i);
common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
common_sampler_accept(smpl, id, true);
result.push_back(id);
if (params.n_max <= (int) result.size()) {
break;
}
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_batch_add(batch, id, n_past++, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch);
prompt_dft.push_back(id);
}
}
void accept(uint16_t n_accepted) override {
// noop
GGML_UNUSED(n_accepted);
// printf("\n\n%s: accepted %d tokens\n\n", __func__, n_accepted);
}
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_config config;
@ -855,6 +976,7 @@ common_speculative * common_speculative_init(
{
bool has_draft = !params.mparams_dft.path.empty();
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
bool has_draft_nextn = (params.type == COMMON_SPECULATIVE_TYPE_NEXTN);
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@ -900,6 +1022,9 @@ common_speculative * common_speculative_init(
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
}
if (has_draft_nextn) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NEXTN, params));
}
}
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@ -921,6 +1046,13 @@ common_speculative * common_speculative_init(
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
break;
}
case COMMON_SPECULATIVE_TYPE_NEXTN: {
impls.push_back(std::make_unique<common_speculative_state_nextn>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft
));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config);

View File

@ -2721,12 +2721,12 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}

View File

@ -800,6 +800,7 @@ int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
}
// TODO: maybe std::move is better?
LLAMA_LOG_DEBUG("%s: copying MTP state (n_token = %lld, n_embd = %lld)\n", __func__, cross.n_token, cross.n_embd);
ctx_mtp.cross = cross;
return 0;
@ -1595,12 +1596,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
break;
}
const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0;
const bool update_mtp_state = gtype == LLM_GRAPH_TYPE_DECODER_MTP // this is MTP layer
|| (hparams.nextn_predict_layers > 0 && n_outputs_all > 0); // or, this is the main LLM, we need to forward state to MTP layer
// set MTP state if needed
if (update_mtp_state) {
// printf("\n\nupdate MTP state: gtype = %d, n_outputs_all = %d\n", (int) gtype, n_outputs_all);
cross.n_embd = hparams.get_n_embd_mtp();
cross.n_token = n_outputs;
cross.n_token = n_outputs_all;
cross.mtp_embd.resize(cross.n_embd*cross.n_token);
}

View File

@ -1788,9 +1788,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// NextN/MTP parameters (GLM-OCR)
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 17: type = LLM_TYPE_1B; break; // GLM-OCR
case 40: type = LLM_TYPE_9B; break;
@ -1821,9 +1818,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open
@ -5475,10 +5469,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
int flags = 0;
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
// skip all tensors in the NextN layers
flags |= TENSOR_SKIP;
}
auto & layer = layers[i];
@ -5505,7 +5495,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);
// NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
// NextN/MTP tensors
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);

View File

@ -81,6 +81,9 @@ llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4_moe(const llama_m
inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0);
cb(inpL, "inp_mtp", il);
inpL = build_lora_mm(mtp_layer.nextn.eh_proj, inpL);
cb(inpL, "inp_mtp_projected", il);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
@ -88,8 +91,7 @@ llm_build_glm4_moe<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4_moe(const llama_m
ggml_tensor * inp_out_ids = build_inp_out_ids();
{
// input for next layer
bool is_output_layer = (il == n_layer - 1);
bool is_output_layer = true; // TODO: we only have one single nextn layer for now, may need to change in the future
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
}
cur = inpL;

View File

@ -81,6 +81,9 @@ llm_build_glm4<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4(const llama_model & m
inpL = ggml_concat(ctx0, inp_token_embd, inp_state_embd, 0);
cb(inpL, "inp_mtp", il);
inpL = build_lora_mm(mtp_layer.nextn.eh_proj, inpL);
cb(inpL, "inp_mtp_projected", il);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
@ -88,8 +91,7 @@ llm_build_glm4<LLM_GRAPH_TYPE_DECODER_MTP>::llm_build_glm4(const llama_model & m
ggml_tensor * inp_out_ids = build_inp_out_ids();
{
// input for next layer
bool is_output_layer = (il == n_layer - 1);
bool is_output_layer = true; // TODO: we only have one single nextn layer for now, may need to change in the future
inpL = build_layer(model, inp_attn, inpL, inp_pos, inp_out_ids, sections, is_output_layer, il);
}
cur = inpL;

View File

@ -639,7 +639,8 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
//if (params_base.speculative.has_dft()) {
{
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative;
@ -662,6 +663,7 @@ private:
params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides;
/*
auto mparams_dft = common_model_params_to_llama(params_dft);
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
@ -672,6 +674,13 @@ private:
params_base.speculative.model_dft = model_dft.get();
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
*/
// FOR TESTING ONLY!!!!!!
params_base.speculative.type = COMMON_SPECULATIVE_TYPE_NEXTN;
params_base.speculative.model_dft = model;
params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft);
params_base.speculative.cparams_dft.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP;
}
std::string & mmproj_path = params_base.mmproj.path;