From cf0f7c0448c2c1736588673114558e5829db7879 Mon Sep 17 00:00:00 2001
From: Aaron Lee
Date: Wed, 13 Aug 2025 02:21:17 -0400
Subject: [PATCH] broad thrust of the mtp implementation

---
 common/speculative.cpp  | 126 ++++++++++++++++++++++++++++++++++++++++
 common/speculative.h    |   9 +++
 include/llama.h         |  17 ++++++
 src/llama-context.cpp   |  59 +++++++++++++++++++
 src/llama-context.h     |   7 +++
 src/llama-graph.cpp     |   4 ++
 src/llama-graph.h       |   1 +
 src/llama-model.cpp     |  12 +++-
 tools/server/server.cpp |  36 ++++++++----
 9 files changed, 260 insertions(+), 11 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 262b2c23e7..e46a0968bd 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -5,6 +5,7 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
+#include "../src/llama-graph.h"
 
 #include
 #include
@@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft(
     }
     return result;
 }
+
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx) {
+
+    llama_tokens result;
+
+    // sample one token from the MTP head -- this does NOT generalize to >1 MTP head
+    result.reserve(1);
+
+    // the model is needed so that the correct architecture-specific MTP graph is built
+    const auto * model = llama_get_model(ctx);
+
+    // holds the MTP graph and its output tensors; must stay alive until the logits have
+    // been read back (the node budget is an assumption -- llama_context::graph_max_nodes()
+    // is not visible from common/)
+    llm_graph_result res_mtp(/*max_nodes =*/ 8192);
+
+    // minimal ubatch that provides positional context (RoPE) for the single drafted token
+    llama_ubatch ubatch_mtp {};
+    ubatch_mtp.n_tokens = 1;
+    ubatch_mtp.pos      = &n_past; // critical for positional encoding
+    // TODO: token/seq_id/output fields are currently left unset:
+    // ubatch_mtp.token  = &id_last;
+    // ubatch_mtp.seq_id = ...;
+
+    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, &res_mtp, ubatch_mtp);
+
+    // the MTP head consumes the hidden state of the last verified token
+    auto * last_embd = llama_get_embeddings_tensor(ctx);
+
+    GGML_ASSERT(model     != nullptr);
+    GGML_ASSERT(last_embd != nullptr);
+
+    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
+
+    if (!gf) {
+        LOG_ERR("%s: failed to build the MTP graph\n", __func__);
+        return result;
+    }
+
+    const auto status = llama_graph_compute(ctx, gf, false);
+    if (status != GGML_STATUS_SUCCESS) {
+        LOG_ERR("%s: failed to compute the MTP graph, status = %d\n", __func__, (int) status);
+        return result;
+    }
+
+    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(&res_mtp);
+
+    if (logits_mtp) {
+        // copy the MTP logits over the context's logits buffer so that the regular
+        // sampler can be reused below
+        llama_set_logits(ctx, logits_mtp);
+    }
+
+    {
+        // FIXME: llama_set_logits() overwrites only the first row of the logits buffer,
+        //        so sampling at last_tok_idx assumes that index maps to row 0
+        common_sampler_sample(smpl, ctx, last_tok_idx, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add the drafted token for the sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // skip accepting the draft token here -- since only one token is drafted, this
+        // cannot affect future outputs; smpl accepts the token if it is not rejected
+        // by the main model later
+        // common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
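For context, this is roughly how the helper above is meant to be driven from a decode loop. This is a minimal sketch, not part of the patch: it assumes a single sequence (seq_id 0), and the names `smpl`, `ctx`, `batch` and `n_past` stand for the caller's existing sampler, context, batch and position counter.

    // sample and accept the next target-model token as usual
    llama_token id_last = common_sampler_sample(smpl, ctx, /*idx =*/ 0);
    common_sampler_accept(smpl, id_last, /*accept_grammar =*/ true);

    // draft exactly one token with the MTP head, reusing the hidden state of id_last
    llama_tokens draft = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, /*last_tok_idx =*/ 0);

    // verify: decode [id_last, draft[0]] in one batch and check whether the target model
    // samples draft[0] at the position right after id_last
    common_batch_clear(batch);
    common_batch_add(batch, id_last, n_past, { 0 }, true);
    if (!draft.empty()) {
        common_batch_add(batch, draft[0], n_past + 1, { 0 }, true);
    }
    llama_decode(ctx, batch);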
diff --git a/common/speculative.h b/common/speculative.h
index e69d7aaa1e..3b04890073 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft(
         struct common_speculative * spec,
         const char *source, const char *dest);
 
+// draft a single token using the model's own MTP (multi-token prediction) head;
+// unlike common_speculative_gen_draft() this does not require a separate draft model
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx);
+
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
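The declaration above is deliberately limited to one drafted token per call. If a model ever exposes more than one MTP head, a thin wrapper could chain single-token calls. The sketch below is purely hypothetical (the name `mtp_gen_draft_n` is made up), and with the current implementation it is only meaningful once the drafted token's hidden state is fed back between steps:

    // hypothetical sketch only: chain single-token MTP drafts up to n_draft tokens
    // NOTE: every call currently reuses the same hidden state, so the chained drafts are
    //       not yet conditioned on each other
    llama_tokens mtp_gen_draft_n(common_sampler * smpl, llama_context * ctx,
                                 llama_token id_last, int32_t n_past, int32_t last_tok_idx, int n_draft) {
        llama_tokens out;
        llama_token cur = id_last;
        for (int i = 0; i < n_draft; ++i) {
            llama_tokens step = mtp_speculative_gen_draft(smpl, ctx, cur, n_past + i, last_tok_idx);
            if (step.empty()) {
                break;
            }
            cur = step[0];
            out.push_back(cur);
        }
        return out;
    }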
diff --git a/include/llama.h b/include/llama.h
index 3bade3ae71..2134f62d52 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -544,12 +544,17 @@
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
+    // Build the graph for the model's MTP (multi-token prediction) head.
+    // hidden_state_inp is the hidden state of the last verified token.
+    LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
+        struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
             const llama_model_quantize_params * params);
 
     //
     // Adapters
     //
@@ -972,6 +977,8 @@
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
+    // Copy the contents of logit_override (n_vocab floats) over the context's output logits buffer
+    LLAMA_API void llama_set_logits(struct llama_context * ctx, struct ggml_tensor * logit_override);
+
     // Get all output token embeddings.
     // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
     // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@@ -994,6 +1001,8 @@
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Get the raw output embeddings tensor of the last decode (used as input to the MTP head)
+    LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
+
     //
     // Vocab
     //
@@ -1452,6 +1461,14 @@
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
+    // note: the following declarations use C++ types (llm_graph_params, llm_graph_result)
+    //       and are therefore only usable from C++
+    LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context * ctx, class llm_graph_result * res, const struct llama_ubatch & ubatch);
+
+    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
+
+    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
+
 #ifdef __cplusplus
 }
 #endif
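Taken together, the new entry points are used in this order. The snippet below is a condensed sketch of the flow in mtp_speculative_gen_draft() above, with error handling omitted; `ctx`, `model`, `id_last` and `n_past` are assumed to be provided by the caller and the node budget is an assumption:

    llm_graph_result res(/*max_nodes =*/ 8192);   // assumed node budget
    llama_ubatch     ub {};                       // 1-token ubatch with pos set by the caller

    llm_graph_params params = llama_mtp_graph_params(ctx, &res, ub);
    ggml_tensor *    h      = llama_get_embeddings_tensor(ctx);   // hidden state of the last verified token
    ggml_cgraph *    gf     = llama_build_mtp_graph(model, params, h, id_last, n_past);

    llama_graph_compute(ctx, gf, /*batched =*/ false);
    llama_set_logits(ctx, llama_graph_result_get_logits(&res));   // expose the MTP logits to the sampler
    // ... sample from ctx as usual ...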
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26a5cf9c3f..26c3e639d8 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -6,6 +6,7 @@
 #include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
+#include "llama-graph.h"
 
 #include
 #include
@@ -522,6 +523,14 @@ float * llama_context::get_logits() {
     return logits;
 }
 
+// Overrides the first output row of the logits buffer with the contents of logit_override.
+// The copy is asynchronous; readers are expected to go through llama_get_logits*(),
+// which synchronize before returning.
+void llama_context::set_logits(struct ggml_tensor * logit_override) {
+    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
+    GGML_ASSERT(backend_res != nullptr);
+    GGML_ASSERT(logits != nullptr);
+
+    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
@@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+ggml_tensor * llama_context::get_embeddings_tensor() {
+    return embd_tensor;
+}
+
 void llama_context::attach_threadpool(
             ggml_threadpool_t threadpool,
             ggml_threadpool_t threadpool_batch) {
@@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
         auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
+        // keep the raw embeddings tensor around for MTP drafting (valid until the next graph is built)
+        embd_tensor = res->get_embd();
 
         if (t_embd && res->get_embd_pooled()) {
             t_embd = res->get_embd_pooled();
@@ -1429,6 +1443,27 @@
     };
 }
 
+llm_graph_params llama_context::mtp_graph_params(
+        llm_graph_result * res,
+        const llama_ubatch & ubatch) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        // FIXME: init_batch() returns an owning pointer -- taking .get() of the temporary
+        //        leaves mctx dangling as soon as this function returns
+        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ 1,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
+}
+
 ggml_status llama_context::graph_compute(
             ggml_cgraph * gf,
                    bool   batched) {
@@ -2233,6 +2268,7 @@ void llama_context::opt_epoch(
     llama_batch_free(batch);
 }
 
 //
 // interface implementation
 //
@@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() {
     return result;
 }
 
 llama_context * llama_init_from_model(
                  llama_model * model,
         llama_context_params   params) {
@@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     return ctx->get_logits_ith(i);
 }
 
+void llama_set_logits(llama_context * ctx, struct ggml_tensor * logit_override) {
+    ctx->set_logits(logit_override);
+}
+
 float * llama_get_embeddings(llama_context * ctx) {
     ctx->synchronize();
 
@@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_tensor();
+}
+
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
@@ -2926,3 +2976,12 @@ void llama_opt_epoch(
             callback_train,
             callback_eval);
 }
+
+llm_graph_params llama_mtp_graph_params(llama_context * ctx, llm_graph_result * res, const llama_ubatch & ubatch) {
+    return ctx->mtp_graph_params(res, ubatch);
+}
+
+ggml_status llama_graph_compute(llama_context * ctx, ggml_cgraph * gf, bool batched) {
+    return ctx->graph_compute(gf, batched);
+}
diff --git a/src/llama-context.h b/src/llama-context.h
index 25c143d56d..44bcdf6d95 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -59,6 +59,7 @@ struct llama_context {
     float * get_embeddings();
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
+    ggml_tensor * get_embeddings_tensor();
 
     void attach_threadpool(
             ggml_threadpool_t threadpool,
@@ -199,6 +200,10 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
+
+    void set_logits(struct ggml_tensor * logit_override);
+
 private:
     llm_graph_params graph_params(
                         llm_graph_result * res,
@@ -240,6 +245,7 @@ private:
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     size_t  embd_size = 0; // capacity (of floats) for embeddings
     float * embd      = nullptr;
+    ggml_tensor * embd_tensor = nullptr; // raw embeddings tensor of the last decode (used for MTP drafting)
 
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -308,3 +314,4 @@ private:
 
     mutable int32_t n_reused = 0; // number of times the previous graph was reused
 };
+
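llama_context::set_logits() above always writes the first n_vocab floats of the output buffer. If the position being sampled does not map to output row 0, an index-aware variant along these lines would be needed; this is a hypothetical sketch (the name `set_logits_ith` is made up), not part of the patch:

    // hypothetical variant: copy the override into the output row belonging to token index i
    void llama_context::set_logits_ith(struct ggml_tensor * logit_override, int32_t i) {
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
        GGML_ASSERT(backend_res != nullptr);
        GGML_ASSERT(logits != nullptr);

        const size_t n_vocab = model.vocab.n_tokens();

        // note: i would have to be mapped through output_ids[] the same way get_logits_ith() does
        ggml_backend_tensor_get_async(backend_res, logit_override,
                logits + (size_t) i*n_vocab, 0, n_vocab*sizeof(float));
    }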
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 053c72d6dc..b5184e4559 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     return relative_bucket;
 }
+
+ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
+    return res->get_logits();
+}
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 6ff49de3a1..10702ed219 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -818,3 +818,4 @@ struct llm_graph_context {
 
 // TODO: better name
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
+
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 667d9e442b..8a9ba84803 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -18673,19 +18673,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 
     return llm->res->get_gf();
 }
 
-ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params,
+ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params & params,
     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
                 llm = std::make_unique(*this, params, hidden_state_inp, last_token_id, n_past);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
     return llm->res->get_gf();
 }
@@ -19004,3 +19006,11 @@ bool llama_model_is_diffusion(const llama_model * model) {
 const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
     return model->tensors_by_name;
 }
+
+ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
+    ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
+    return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
+}
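For GLM-4.5-style MTP, the graph built behind build_mtp_graph() conceptually embeds the last sampled token, fuses that embedding with the previous hidden state, runs the dedicated MTP layer and projects to vocabulary logits. The following is only a rough conceptual sketch: the tensor names (mtp_proj, output_head, ...) are placeholders, the norms and the transformer block are elided, and this is not the graph the builder above actually constructs.

    // rough conceptual sketch only -- weight names are placeholders, not real model tensors
    ggml_tensor * build_mtp_head_sketch(ggml_context * ctx0,
            ggml_tensor * tok_embd,      // [n_embd, n_vocab] token embedding table
            ggml_tensor * mtp_proj,      // [2*n_embd, n_embd] fuse/down-projection (hypothetical)
            ggml_tensor * output_head,   // [n_embd, n_vocab]  lm head
            ggml_tensor * h_prev,        // [n_embd] hidden state of the last verified token
            llama_token   id_last) {
        GGML_UNUSED(id_last);

        // embed the last sampled token (ids are filled with id_last at compute time)
        ggml_tensor * ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_tensor * e   = ggml_get_rows(ctx0, tok_embd, ids);            // [n_embd, 1]

        // concatenate the previous hidden state with the new embedding, project back to n_embd
        ggml_tensor * cat = ggml_concat(ctx0, h_prev, e, /*dim =*/ 0);     // [2*n_embd, 1]
        ggml_tensor * cur = ggml_mul_mat(ctx0, mtp_proj, cat);             // [n_embd, 1]

        // ... one transformer block over `cur` using the MTP layer's weights ...

        // project to vocabulary logits
        return ggml_mul_mat(ctx0, output_head, cur);                       // [n_vocab, 1]
    }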
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a9ad900ce3..29d551ea51 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1294,7 +1294,8 @@ struct server_slot {
     mtmd_context * mctx = nullptr;
 
     common_speculative * spec = nullptr;
     bool has_mtp = false;
+    int32_t last_tok_idx = -1; // output index of the last sampled token (needed by the MTP draft)
 
     std::vector<common_adapter_lora_info> lora;
 
@@ -1432,8 +1433,8 @@ struct server_slot {
     }
 
     bool can_speculate() const {
-        // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
-        return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
+        return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
     }
 
     void add_token(const completion_token_output & token) {
@@ -1993,7 +1994,7 @@ struct server_context {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
             return false;
         }
-        
+
         vocab = llama_model_get_vocab(model);
 
         n_ctx = llama_n_ctx(ctx);
@@ -3531,6 +3532,7 @@ struct server_context {
                 const int tok_idx = slot.i_batch - i;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+                slot.last_tok_idx = tok_idx;
 
                 slot.i_batch = -1;
 
@@ -3601,15 +3607,25 @@
                     continue;
                 }
 
+                SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp);
+
                 llama_token id = slot.sampled;
 
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
+                llama_tokens draft;
+                if (slot.has_mtp) {
+                    draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                } else {
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
 
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+
+                    draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                }
 
                 // ignore small drafts
                 if (slot.params.speculative.n_min > (int) draft.size()) {
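Downstream of this hunk the existing speculative-decoding path is reused unchanged: the draft is appended to slot.batch_spec, decoded together with the sampled token, and accepted only as far as the target model agrees. Roughly, the code that follows looks like this (shown for orientation, not modified by the patch):

    // build the speculation batch: the sampled token plus the drafted tokens
    common_batch_clear(slot.batch_spec);
    common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);

    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
    }

    llama_decode(ctx, slot.batch_spec);

    // the accepted tokens from the speculation
    const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);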