some work towards building mtp layer graph
This commit is contained in:
parent
db60623e79
commit
e434f87cc7
@@ -4507,6 +4507,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        // but only PROCESS up to the last layer (skipping the final NextN layer) in the forward pass
        for (int i = 0; i < n_layer; ++i) {
            int flags = 0;

            if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
                // skip all tensors in the NextN layers
                flags |= TENSOR_SKIP;
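Note: the condition above marks every tensor in the trailing nextn_predict_layers layers with TENSOR_SKIP; all earlier layers are unaffected. A minimal standalone sketch of the same index arithmetic, with illustrative layer counts that are not taken from the commit:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // illustrative values only: e.g. a model with 47 layers where the last one is a NextN (MTP) layer
        const uint32_t n_layer              = 47;
        const uint32_t nextn_predict_layers = 1;

        for (uint32_t i = 0; i < n_layer; ++i) {
            const bool skip = nextn_predict_layers > 0 && i >= n_layer - nextn_predict_layers;
            if (skip) {
                printf("layer %u gets TENSOR_SKIP\n", i); // prints only for layer 46
            }
        }
        return 0;
    }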
@@ -13919,6 +13920,144 @@ struct llm_build_glm4_moe : public llm_graph_context {
    }
};

struct llm_build_glm4_moe_mtp : public llm_graph_context {
    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
        // For v0, let's rebuild the computational graph for every step; this mimics the vLLM impl's parameterization
        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
    ) : llm_graph_context(params) {

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        // Assuming a single MTP layer at the end
        const int il = hparams.n_layer - 1;
        const auto & mtp_layer = model.layers[il];

        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_pos, n_past);
        llm_graph_input_attn_no_cache * inp_attn = nullptr; // WIP: not constructed yet

        ggml_tensor * cur;

        // get MTP embedding for last (conventionally sampled) token
        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_token_id, last_token_id);
        ggml_tensor * token_emb      = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);

        // vLLM l99: previous_hidden_states = self.hnorm(previous_hidden_states)
        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);

        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
        // now proceed through the last layer (skipped in the main model)
        ggml_tensor * inpSA = cur;

        // Pre-attention norm for the MTP block
        ggml_tensor * attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);

        // self-attention
        {
            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, attn_inp);
            if (mtp_layer.bq) {
                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
            }
            cb(Qcur, "Qcur", il);

            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, attn_inp);
            if (mtp_layer.bk) {
                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
            }
            cb(Kcur, "Kcur", il);

            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, attn_inp);
            if (mtp_layer.bv) {
                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
            }
            cb(Vcur, "Vcur", il);

            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // Apply Q/K norm if available (GLM-4.5 355B variant)
            if (mtp_layer.attn_q_norm) {
                Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
                cb(Qcur, "Qcur_normed", il);
            }
            if (mtp_layer.attn_k_norm) {
                Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
                cb(Kcur, "Kcur_normed", il);
            }

            Qcur = ggml_rope_ext(
                ctx0, Qcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            Kcur = ggml_rope_ext(
                ctx0, Kcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );

            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);

            cur = build_attn(inp_attn,
                    mtp_layer.wo, NULL,
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
        }
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);

        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);

        // moe ffn for nextn block
        {
            // Process routed experts using existing MoE infrastructure
            ggml_tensor * routed_out = build_moe_ffn(cur,
                    mtp_layer.ffn_gate_inp,
                    mtp_layer.ffn_up_exps,
                    mtp_layer.ffn_gate_exps,
                    mtp_layer.ffn_down_exps,
                    mtp_layer.ffn_exp_probs_b,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU, hparams.expert_weights_norm,
                    true, hparams.expert_weights_scale,
                    (llama_expert_gating_func_type) hparams.expert_gating_func,
                    il);
            cb(routed_out, "ffn_moe_out", il);

            // Process shared expert on original input
            ggml_tensor * shared_out = build_ffn(cur,
                    mtp_layer.ffn_up_shexp,   NULL, NULL,
                    mtp_layer.ffn_gate_shexp, NULL, NULL,
                    mtp_layer.ffn_down_shexp, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(shared_out, "ffn_shexp_out", il);

            // Final output: routed_output + shared_output
            cur = ggml_add(ctx0, routed_out, shared_out);
            cb(cur, "ffn_out", il);
        }

        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);

        res->t_logits = cur;

        ggml_build_forward_expand(gf, res->t_logits);
    }
};

struct llm_build_nemotron : public llm_graph_context {
    llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;
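For reference, the fuse step near the top of llm_build_glm4_moe_mtp (enorm/hnorm, ggml_concat, eh_proj) concatenates two n_embd vectors and projects the result back down to n_embd. A standalone, shape-only sketch of that step, with a tiny illustrative n_embd and placeholder weights (the real hidden size and values are model-dependent and not taken from the commit):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 8; // tiny illustrative size

        std::vector<float> token_emb_norm(n_embd, 0.1f);    // RMS-normed embedding of the sampled token
        std::vector<float> hidden_state_norm(n_embd, 0.2f); // RMS-normed hidden state from the main model

        // concat along the embedding dimension -> 2 * n_embd values
        std::vector<float> combined;
        combined.insert(combined.end(), token_emb_norm.begin(),    token_emb_norm.end());
        combined.insert(combined.end(), hidden_state_norm.begin(), hidden_state_norm.end());

        // eh_proj maps the fused 2 * n_embd vector back down to n_embd
        std::vector<float> eh_proj((size_t) 2 * n_embd * n_embd, 0.01f); // placeholder weights
        std::vector<float> cur(n_embd, 0.0f);
        for (int o = 0; o < n_embd; ++o) {
            for (int i = 0; i < 2 * n_embd; ++i) {
                cur[o] += eh_proj[(size_t) o * 2 * n_embd + i] * combined[i];
            }
        }

        printf("combined: %zu values -> cur: %zu values\n", combined.size(), cur.size());
        return 0;
    }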
@@ -1432,7 +1432,7 @@ struct server_slot {
    }

    bool can_speculate() const {
-       return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+       return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
    }

    void add_token(const completion_token_output & token) {
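With has_mtp set, a slot can now speculate without a separate draft context; the draft token would instead have to come from the MTP head's logits (res->t_logits in the new graph builder). That driver code is not part of this commit yet. A hypothetical sketch of a greedy draft pick over those logits, once they have been read back from the backend (mtp_pick_draft_token and everything around it are illustrations, not actual llama.cpp API):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;

    // Hypothetical: pick one draft token greedily from the MTP head's logits.
    llama_token mtp_pick_draft_token(const std::vector<float> & mtp_logits) {
        const auto it = std::max_element(mtp_logits.begin(), mtp_logits.end());
        return (llama_token) std::distance(mtp_logits.begin(), it);
    }

    int main() {
        // toy logits over a 5-token vocabulary; index 2 wins
        const std::vector<float> logits = { -1.0f, 0.2f, 3.5f, 0.0f, -2.0f };
        const llama_token draft = mtp_pick_draft_token(logits);
        return draft == 2 ? 0 : 1;
    }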
@@ -2122,14 +2122,16 @@ struct server_context {
                common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
            }
        }

        // if model has MTP and no draft model is specified...
        else if (llama_model_n_nextn_layer(model) > 0) {
            SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
            slot.has_mtp = true;

            // assume one speculative token (true of all well-known MTP models so far)
            slot.batch_spec = llama_batch_init(2, 0, 1);
            params_base.speculative.n_min = 0;
            params_base.speculative.n_max = 1;
        }

        SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
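Since n_max is pinned to 1, slot.batch_spec only ever needs to hold two tokens: the conventionally sampled token and the single MTP draft. A sketch of how such a verification batch could be filled, assuming llama.cpp's common_batch_clear/common_batch_add helpers (this function and the names around it are illustrative, not part of the commit):

    #include <vector>

    #include "common.h" // common_batch_clear / common_batch_add
    #include "llama.h"

    // Illustrative only: fill the 2-token speculative batch for verification.
    void fill_mtp_spec_batch(llama_batch & batch_spec,
                             llama_token id_sampled, llama_token id_draft,
                             llama_pos n_past, llama_seq_id seq_id) {
        common_batch_clear(batch_spec);

        // position n_past: the token that was just sampled conventionally
        common_batch_add(batch_spec, id_sampled, n_past,     { seq_id }, true);

        // position n_past + 1: the single token drafted by the MTP head (n_max == 1)
        common_batch_add(batch_spec, id_draft,   n_past + 1, { seq_id }, true);
    }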