replace standard sampler with greedy sampler for mtp draft
parent 471e026327
commit 98bc0c6bf2
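
The MTP draft token is now picked greedily inside the compute graph rather than by the standard host-side sampler chain: the MTP graph gains a ggml_argmax node over the draft logits, and llama_build_and_execute_mtp_graph returns the drafted token id directly. A minimal sketch of the pattern the hunks below implement (illustrative names only; ctx0, gf, sched, gctx and logits are placeholders, not identifiers from this tree):

    // graph construction: reduce the logits row to the index of its maximum
    struct ggml_tensor * ids = ggml_argmax(ctx0, logits);  // GGML_TYPE_I32 result
    ggml_set_name(ids, "argmax_result");                   // name it so it can be found later
    ggml_build_forward_expand(gf, ids);

    // after ggml_backend_sched_graph_compute(sched, gf):
    struct ggml_tensor * t = ggml_get_tensor(gctx, "argmax_result");
    llama_token id = 0;
    ggml_backend_tensor_get(t, &id, 0, sizeof(id));        // device -> host copy of one int32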
@@ -387,9 +387,10 @@ llama_token mtp_speculative_gen_draft(
         /*.logits = */ logits_data
     };
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+
     //LOG_INF("updating kv cache for n_past: %d\n", n_past);
 
     /*
     if (!smpl) {
         return -1;
     }
@@ -405,6 +406,7 @@ llama_token mtp_speculative_gen_draft(
     const llama_token id = cur_p->data[0].id;
     return id;
     }
     */
+    // LOG_INF("cur_p->size: %d\n", cur_p->size);
 
 
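
Note: the sampler path retained above as a comment picked the top candidate via cur_p->data[0].id after running the sampler; the new path never materializes the draft logits on the host. For a single draft token, the greedy choice is just an argmax over one row of logits, equivalent to this host-side sketch (for reference only, not part of the diff):

    // host-side equivalent of the in-graph greedy pick; assumes n_vocab > 0
    static llama_token greedy_pick(const float * logits, int32_t n_vocab) {
        int32_t best = 0;
        for (int32_t i = 1; i < n_vocab; ++i) {
            if (logits[i] > logits[best]) {
                best = i; // keep the index of the largest logit
            }
        }
        return best; // llama_token is a 32-bit integer id
    }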
@@ -1454,7 +1454,7 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
             callback_eval);
 }
 
-void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
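
With the declaration in the header and the definition now both returning llama_token, callers consume the drafted token as the return value instead of re-reading it from a logits buffer, e.g. (usage sketch mirroring the first hunk):

    const llama_token id_draft = llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);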
@@ -3044,13 +3044,29 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
+    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
 
     //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
 
-    if (logits_mtp) {
-        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    }
+    //if (logits_mtp) {
+    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
+    //}
+
+    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
+
+    llama_token token_id = 0; // The C++ variable to hold the result.
+
+    // ggml_backend_tensor_get is the function for GPU->CPU copies.
+    // We are copying a single 32-bit integer.
+    ggml_backend_tensor_get(
+        token_id_tensor,
+        &token_id,          // Pointer to our C++ variable
+        0,                  // Starting offset in bytes
+        sizeof(llama_token) // Number of bytes to copy
+    );
+
 
     ggml_backend_sched_free(sched);
+
+    return token_id;
 }
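
One caveat with the readback above: ggml_get_tensor returns NULL when no tensor with the requested name exists in the context, and the result is dereferenced unconditionally. A defensive variant (sketch, not in the diff):

    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
    llama_token token_id = -1; // sentinel in case the lookup fails
    if (token_id_tensor != NULL) {
        // copy the single int32 result out of backend (possibly GPU) memory
        ggml_backend_tensor_get(token_id_tensor, &token_id, 0, sizeof(token_id));
    }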
@@ -14100,6 +14100,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         res->t_logits = cur;
 
         ggml_build_forward_expand(gf, res->t_logits);
+
+        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
+        ggml_set_name(token_id_tensor, "mtp_argmax_result");
+        ggml_build_forward_expand(gf, token_id_tensor);
     }
 };
 
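
For context: ggml_argmax reduces along the vocabulary dimension and yields a GGML_TYPE_I32 tensor with one id per row of cur; expanding it into gf forces the backend to actually compute it, and the name "mtp_argmax_result" set here is the key the execution side looks up with ggml_get_tensor.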