Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp

feat: implemented sampling for MTP
Aaron Lee 2025-09-13 02:57:01 -04:00 committed by GitHub
commit c6237c71ff
7 changed files with 61 additions and 83 deletions

View File

@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
return samplers;
}
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
llama_sampler_apply(gsmpl->chain, cur_p);
}

View File

@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
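
Note: the new helper applies the configured sampler chain to an externally prepared candidate array without accepting a token into the sampler state. A minimal usage sketch, assuming a `common_sampler` created through the usual init path and a context that already holds logits for position `idx`; the wrapper function and its name are illustrative, not part of this PR, and the candidate buffer is assumed to be large enough for n_vocab entries, as in the MTP draft path further down:

#include "llama.h"
#include "sampling.h"

// Fill the sampler's candidate array from the logits at `idx`, run the chain,
// and take the top candidate as the draft token (mirrors the MTP draft path).
static llama_token draft_from_logits(struct common_sampler * smpl, struct llama_context * ctx, int idx) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int     n_vocab = llama_n_vocab(vocab);
    const float * logits  = llama_get_logits_ith(ctx, idx);

    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
    cur_p->size = n_vocab;
    for (int i = 0; i < n_vocab; ++i) {
        cur_p->data[i].id    = i;
        cur_p->data[i].logit = logits[i];
    }
    cur_p->sorted = false;

    common_sampler_apply_chain(smpl, cur_p); // apply the chain, no accept
    return cur_p->data[0].id;                // top candidate, as in the draft path
}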

View File

@@ -370,56 +370,35 @@ llama_token mtp_speculative_gen_draft(
int32_t n_past,
int32_t last_tok_idx) {
llama_token token_data[] = { id_last };
llama_pos pos_data[] = { n_past };
int32_t n_seq_id_data[] = { 1 };
llama_seq_id seq_id_data_internal[] = { 0 };
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
llama_batch batch = {
/*.n_tokens = */ 1,
/*.token = */ token_data,
/*.embd = */ nullptr,
/*.pos = */ pos_data,
/*.n_seq_id = */ n_seq_id_data,
/*.seq_id = */ seq_id_data,
/*.logits = */ logits_data
};
return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);
/*
if (!smpl) {
return -1;
}
else {
common_sampler_sample(smpl, ctx, last_tok_idx, true);
const auto* cur_p = common_sampler_get_candidates(smpl);
//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
//}
llama_batch batch = llama_batch_init(1, 0, 1);
common_batch_add(batch, id_last, n_past, {0}, true);
const llama_token id = cur_p->data[0].id;
return id;
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
cur_p->size = n_vocab;
for (int i = 0; i < n_vocab; ++i) {
cur_p->data[i].id = i;
cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
}
*/
// LOG_INF("cur_p->size: %d\n", cur_p->size);
cur_p->sorted = false;
common_sampler_apply_chain(smpl, cur_p);
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by main model later
// common_sampler_accept(smpl, id, true);
llama_batch_free(batch);
//llama_tokens result;
//result.reserve(1);
//result.push_back(id);
//return result;
return id;
}
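
Note: the stack-allocated single-token batch above replaces the heap-based construction that the removed code path used. For comparison, the equivalent one-token batch with the dynamic API looks like the sketch below (this is the pattern the old code used; the stack variant avoids the allocation and the matching llama_batch_free):

// One token id_last at position n_past, sequence 0, with logits requested so
// the MTP head's output can be sampled afterwards.
llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/1, /*embd=*/0, /*n_seq_max=*/1);
common_batch_add(batch, id_last, n_past, { 0 }, /*logits=*/true);
// ... run the MTP graph with `batch` ...
llama_batch_free(batch);
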
@@ -438,4 +417,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
}
tokens.clear();
}
}

View File

@@ -44,9 +44,9 @@ llama_token mtp_speculative_gen_draft(
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start = 0, size_t n_tokens = -1);
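
Note: a short usage sketch for the declarations above. `spec`, `params`, `prompt_tokens`, `id_last`, `ctx` and the `pending` vector are assumed to exist in the caller; the fields of mtp_kv_update_data are not shown in this diff, and the intended semantics are inferred from the names and the speculative.cpp hunk:

// Draft-model speculation: returns a batch of draft tokens for the prompt.
llama_tokens draft = common_speculative_gen_draft(spec, params, prompt_tokens, id_last);

// Keep the MTP layer's KV cache in sync with the accepted tokens;
// the default arguments process the whole `pending` vector.
mtp_update_kv_cache(ctx, pending);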

View File

@@ -1454,8 +1454,8 @@
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
#ifdef __cplusplus
}
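
Note: with the return type changed to void, a caller no longer receives the argmax token id from the graph; it reads the logits that the call stores in the context and samples from them. A before/after sketch of a call site (variable names illustrative):

// Before: the graph computed an in-graph argmax and returned the draft token id.
// llama_token id = llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);

// After: execute the graph, then read the logits it wrote for last_tok_idx.
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
const float * logits = llama_get_logits_ith(ctx, last_tok_idx);
// ... feed `logits` into the common_sampler chain, as in the speculative.cpp hunk above ...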

View File

@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
callback_eval);
}
llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
const auto * model = llama_get_model(ctx);
@@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
if (!gf) {
LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
if (sched) ggml_backend_sched_free(sched);
return;
}
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
@@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_backend_sched_graph_compute(sched, gf); // execute the graph
//struct ggml_tensor * logits_mtp = res_mtp->get_logits();
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
//if (logits_mtp) {
// ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
//}
struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
llama_token token_id = 0; // The C++ variable to hold the result.
// ggml_backend_tensor_get is the function for GPU->CPU copies.
// We are copying a single 32-bit integer.
ggml_backend_tensor_get(
token_id_tensor,
&token_id, // Pointer to our C++ variable
0, // Starting offset in bytes
sizeof(llama_token) // Number of bytes to copy
);
if (logits_mtp) {
float * logits_dest = ctx->get_logits_ith(last_tok_idx);
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
if (backend_res) {
// ggml_backend_tensor_get performs the backend (e.g. GPU) -> CPU copy.
// Here it copies the full logits tensor for the drafted position.
ggml_backend_tensor_get(logits_mtp,
logits_dest, // destination: the context's logits slot for last_tok_idx
0, // starting offset in bytes
ggml_nbytes(logits_mtp)); // number of bytes to copy (whole tensor)
} else {
LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
}
} else {
LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
}
ggml_backend_sched_free(sched);
return token_id;
}
}
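
Note: the logits copy above follows the usual ggml readback pattern: after ggml_backend_sched_graph_compute, check which backend holds the tensor, then copy its contents to host memory with ggml_backend_tensor_get. A generic sketch of that pattern (`t`, `sched` and the host buffer are illustrative):

#include <cstdint>
#include <vector>
#include "ggml-backend.h"

// Copy a computed tensor back to the host after the scheduler has run the graph.
static bool read_tensor_to_host(ggml_backend_sched_t sched, struct ggml_tensor * t, std::vector<uint8_t> & out) {
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, t);
    if (backend == nullptr) {
        return false; // tensor was not assigned to any backend
    }
    out.resize(ggml_nbytes(t));
    ggml_backend_tensor_get(t, out.data(), 0, ggml_nbytes(t));
    return true;
}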

View File

@@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
// For v0, rebuild the computational graph for every step; this mimics the vLLM implementation's parameterization
llama_token last_token_id, int n_past
) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
//llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
auto * inp_attn = build_attn_inp_kv_unified();
ggml_tensor * cur;
// get MTP embedding for last (conventionally sampled) token
// ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
// LLAMA_LOG_INFO("step: '%d'\n", 5641);
@@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
//ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
//ggml_set_no_alloc(ctx0, true);
ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
@@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
// now proceed through last layer (skipped in main model)
ggml_tensor * inpSA = cur;
@@ -14096,14 +14093,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
res->t_logits = cur;
ggml_build_forward_expand(gf, res->t_logits);
struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
ggml_set_name(token_id_tensor, "mtp_argmax_result");
ggml_build_forward_expand(gf, token_id_tensor);
}
};
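
Note: with the in-graph argmax removed, the builder's only output is the logits tensor registered on the result object; llama_build_and_execute_mtp_graph retrieves it through res_mtp->get_logits() and copies it into the context's logits buffer (see the context hunk above). A sketch of the producer/consumer pairing, with names taken from the hunks:

// Producer (this builder): expose the MTP head output as the graph's logits.
res->t_logits = cur;                         // logits for the single drafted position
ggml_build_forward_expand(gf, res->t_logits);

// Consumer (llama_build_and_execute_mtp_graph): after computing the graph,
// fetch the same tensor and copy it to the host logits slot for last_tok_idx.
struct ggml_tensor * logits_mtp = res_mtp->get_logits();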