some work towards building mtp layer graph

Aaron Lee 2025-08-11 01:21:47 -04:00
parent db60623e79
commit e434f87cc7
2 changed files with 149 additions and 8 deletions

src/llama-model.cpp

@@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        // but only PROCESS up to last layer (skipping final NextN layer) in forward pass
        for (int i = 0; i < n_layer; ++i) {
            int flags = 0;
            if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
                // skip all tensors in the NextN layers
                flags |= TENSOR_SKIP;
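A toy standalone check of the skip condition in this hunk (layer counts are illustrative, not the real model's): with n_layer = 4 and nextn_predict_layers = 1, only the final layer is flagged TENSOR_SKIP.

#include <cstdint>
#include <cstdio>

int main() {
    const int      n_layer              = 4; // toy values, not GLM-4.5's real counts
    const uint32_t nextn_predict_layers = 1; // the diff assumes a single NextN/MTP layer at the end
    for (int i = 0; i < n_layer; ++i) {
        const bool skip = nextn_predict_layers > 0 &&
                          static_cast<uint32_t>(i) >= n_layer - nextn_predict_layers;
        printf("layer %d: %s\n", i, skip ? "TENSOR_SKIP" : "load");
    }
    return 0;
}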
@@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context {
    }
};

struct llm_build_glm4_moe_mtp : public llm_graph_context {
    // For v0, rebuild the computational graph on every step; this mimics the vLLM implementation's parameterization.
    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
            ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
    ) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        // Assuming a single MTP layer at the end
        const int il = hparams.n_layer - 1;
        const auto & mtp_layer = model.layers[il];

        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_pos, n_past);

        // NOTE: attention input is not constructed yet; build_attn() below still receives nullptr
        llm_graph_input_attn_no_cache * inp_attn = nullptr;

        ggml_tensor * cur;

        // get the MTP embedding for the last (conventionally sampled) token
        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_token_id, last_token_id);
        ggml_tensor * token_emb      = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);

        // vLLM l99: previous_hidden_states = self.hnorm(previous_hidden_states)
        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);

        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat along the embedding dim
        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj

        // now run the final NextN layer, which is skipped in the main model's forward pass
        ggml_tensor * inpSA = cur;

        // pre-attention norm for the MTP block (its output feeds Q/K/V below)
        cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
        // self-attention
        {
            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
            if (mtp_layer.bq) {
                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
            }
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
            if (mtp_layer.bk) {
                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
            }
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
            if (mtp_layer.bv) {
                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
            }
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // Apply Q/K norm if available (GLM-4.5 355B variant)
            if (mtp_layer.attn_q_norm) {
                Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
                cb(Qcur, "Qcur_normed", il);
            }
            if (mtp_layer.attn_k_norm) {
                Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
                cb(Kcur, "Kcur_normed", il);
            }

            Qcur = ggml_rope_ext(
                ctx0, Qcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );
            Kcur = ggml_rope_ext(
                ctx0, Kcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            cur = build_attn(inp_attn,
                    mtp_layer.wo, NULL,
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
        }
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);

        // moe ffn for nextn block
        {
            // Process routed experts using existing MoE infrastructure
            ggml_tensor * routed_out = build_moe_ffn(cur,
                    mtp_layer.ffn_gate_inp,
                    mtp_layer.ffn_up_exps,
                    mtp_layer.ffn_gate_exps,
                    mtp_layer.ffn_down_exps,
                    mtp_layer.ffn_exp_probs_b,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU, hparams.expert_weights_norm,
                    true, hparams.expert_weights_scale,
                    (llama_expert_gating_func_type) hparams.expert_gating_func,
                    il);
            cb(routed_out, "ffn_moe_out", il);

            // Process shared expert on original input
            ggml_tensor * shared_out = build_ffn(cur,
                    mtp_layer.ffn_up_shexp,   NULL, NULL,
                    mtp_layer.ffn_gate_shexp, NULL, NULL,
                    mtp_layer.ffn_down_shexp, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(shared_out, "ffn_shexp_out", il);

            // Final output: routed_output + shared_output
            cur = ggml_add(ctx0, routed_out, shared_out);
            cb(cur, "ffn_out", il);
        }

        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);

        res->t_logits = cur;
        ggml_build_forward_expand(gf, res->t_logits);
    }
};

struct llm_build_nemotron : public llm_graph_context {
    llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;
tools/server/server.cpp

@@ -1432,7 +1432,7 @@ struct server_slot {
    }

    bool can_speculate() const {
        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
        return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
    }

    void add_token(const completion_token_output & token) {
@@ -2122,14 +2122,16 @@ struct server_context {
                common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
            }
        }
        // if model has MTP and no draft model is specified...
        else if (llama_model_n_nextn_layer(model) > 0) {
        SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
        slot.has_mtp = true;
        // assume one speculative token (true of all well-known MTP models so far)
        slot.batch_spec = llama_batch_init(2, 0, 1);
        params_base.speculative.n_min = 0;
        params_base.speculative.n_max = 1;
            SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
            slot.has_mtp = true;
            // assume one speculative token (true of all well-known MTP models so far)
            slot.batch_spec = llama_batch_init(2, 0, 1);
            params_base.speculative.n_min = 0;
            params_base.speculative.n_max = 1;
        }

        SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
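For context on how the server side would fit together once the graph above is hooked up: with n_max = 1 and batch_spec sized for two tokens, the slot proposes exactly one draft token from the NextN head and verifies it alongside the conventionally sampled token. A rough sketch of that step, where mtp_draft_token() is a hypothetical stand-in for evaluating the llm_build_glm4_moe_mtp graph (not an API in this patch or in llama.cpp):

// inside the slot's generation loop, after sampling `id` from the main model:
llama_token id    = common_sampler_sample(slot.smpl, ctx, -1);
llama_token draft = mtp_draft_token(ctx, id, slot.n_past); // hypothetical MTP call

// batch_spec was initialized for 2 tokens above: the accepted token plus one draft
common_batch_clear(slot.batch_spec);
common_batch_add(slot.batch_spec, id,    slot.n_past,     { slot.id }, true);
common_batch_add(slot.batch_spec, draft, slot.n_past + 1, { slot.id }, true);

// one decode evaluates both positions; the draft is kept only if the main model samples it next
llama_decode(ctx, slot.batch_spec);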