Replace the standard sampler with a greedy (argmax) sampler for the MTP draft token: `llama_build_and_execute_mtp_graph` now returns the argmax token id computed in-graph (`mtp_argmax_result`) instead of exporting logits for CPU-side sampling.

This commit is contained in:
Aaron Lee 2025-08-26 01:26:51 -04:00
parent 471e026327
commit 98bc0c6bf2
4 changed files with 29 additions and 7 deletions

View File

@ -387,9 +387,10 @@ llama_token mtp_speculative_gen_draft(
/*.logits = */ logits_data
};
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);
/*
if (!smpl) {
return -1;
}
@ -405,6 +406,7 @@ llama_token mtp_speculative_gen_draft(
const llama_token id = cur_p->data[0].id;
return id;
}
*/
// LOG_INF("cur_p->size: %d\n", cur_p->size);

View File

@ -1454,7 +1454,7 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
#ifdef __cplusplus

View File

@ -2995,7 +2995,7 @@ void llama_opt_epoch(
callback_eval);
}
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
const auto * model = llama_get_model(ctx);
@ -3044,13 +3044,29 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_backend_sched_graph_compute(sched, gf); // execute the graph
struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
//struct ggml_tensor * logits_mtp = res_mtp->get_logits();
//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
if (logits_mtp) {
ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
}
//if (logits_mtp) {
// ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
//}
struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
llama_token token_id = 0; // The C++ variable to hold the result.
// ggml_backend_tensor_get is the function for GPU->CPU copies.
// We are copying a single 32-bit integer.
ggml_backend_tensor_get(
token_id_tensor,
&token_id, // Pointer to our C++ variable
0, // Starting offset in bytes
sizeof(llama_token) // Number of bytes to copy
);
ggml_backend_sched_free(sched);
return token_id;
}

View File

@ -14100,6 +14100,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
res->t_logits = cur;
ggml_build_forward_expand(gf, res->t_logits);
struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
ggml_set_name(token_id_tensor, "mtp_argmax_result");
ggml_build_forward_expand(gf, token_id_tensor);
}
};