diff --git a/common/speculative.cpp b/common/speculative.cpp
index 9f8384abb1..edeffe2d8e 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -387,9 +387,10 @@ llama_token mtp_speculative_gen_draft(
         /*.logits = */ logits_data
     };
 
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
 
     //LOG_INF("updating kv cache for n_past: %d\n", n_past);
+    /*
     if (!smpl) {
         return -1;
     }
@@ -405,6 +406,7 @@ llama_token mtp_speculative_gen_draft(
         const llama_token id = cur_p->data[0].id;
         return id;
     }
+    */
 
     // LOG_INF("cur_p->size: %d\n", cur_p->size);
 
diff --git a/include/llama.h b/include/llama.h
index 1de8a963cc..015c777763 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1454,7 +1454,7 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 62d7898b5f..1f04b72145 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
         callback_eval);
 }
 
-void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
         const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
     const auto * model = llama_get_model(ctx);
 
@@ -3044,13 +3044,29 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
+    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
 
-    if (logits_mtp) {
-        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    }
+    //if (logits_mtp) {
+    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
+    //}
+    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
+
+
+    llama_token token_id = 0; // The C++ variable to hold the result.
+
+    // ggml_backend_tensor_get is the function for GPU->CPU copies.
+    // We are copying a single 32-bit integer.
+    ggml_backend_tensor_get(
+        token_id_tensor,
+        &token_id,          // Pointer to our C++ variable
+        0,                  // Starting offset in bytes
+        sizeof(llama_token) // Number of bytes to copy
+    );
 
     ggml_backend_sched_free(sched);
+
+    return token_id;
 }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 04743e01f3..f9921e4b6d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -14100,6 +14100,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         res->t_logits = cur;
 
         ggml_build_forward_expand(gf, res->t_logits);
+
+        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
+        ggml_set_name(token_id_tensor, "mtp_argmax_result");
+        ggml_build_forward_expand(gf, token_id_tensor);
     }
 };
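
Note (not part of the patch): the changes above rely on a general ggml pattern -- build an argmax node into the forward graph, give it a stable name with ggml_set_name, locate it after graph execution with ggml_get_tensor, and copy the single int32 result to the host with ggml_backend_tensor_get. Below is a minimal, CPU-only sketch of that pattern in isolation, assuming current ggml headers (ggml.h, ggml-alloc.h, ggml-backend.h, ggml-cpu.h); the tensor name "argmax_result" and the toy logits are illustrative only and do not appear in this patch.

// Sketch only: name a graph node, run the graph, read the result back by name.
#include <cstdio>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

int main() {
    // Host-side "logits" for a toy vocabulary of 8 tokens (illustrative values).
    std::vector<float> logits = { 0.1f, 0.2f, 4.0f, 0.3f, 0.0f, 1.5f, 0.7f, 0.2f };

    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data will live in a backend buffer
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * t_logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (int64_t) logits.size());
    ggml_tensor * t_argmax = ggml_argmax(ctx, t_logits); // I32 tensor with one element
    ggml_set_name(t_argmax, "argmax_result");            // so it can be found by name later

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, t_argmax);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(alloc, gf);

    // Upload the inputs and run the graph.
    ggml_backend_tensor_set(t_logits, logits.data(), 0, logits.size() * sizeof(float));
    ggml_backend_graph_compute(backend, gf);

    // Look the node up by name and copy the single int32 back to the host,
    // the same readback the patch performs with "mtp_argmax_result".
    ggml_tensor * found = ggml_get_tensor(ctx, "argmax_result");
    int32_t best = 0;
    ggml_backend_tensor_get(found, &best, 0, sizeof(best));
    printf("argmax = %d\n", best); // expected: 2

    ggml_gallocr_free(alloc);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}

The design choice mirrored here is that the argmax runs inside the graph (on whatever backend computes it), so only a 4-byte token id crosses the device boundary instead of a full logits row.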