diff --git a/common/speculative.cpp b/common/speculative.cpp
index fa784f62f6..9f8384abb1 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft(
     int32_t n_past,
     int32_t last_tok_idx) {
 
-    const auto * model = llama_get_model(ctx);
-    auto * last_embd = llama_get_embeddings_tensor(ctx);
+    llama_token token_data[] = { id_last };
+    llama_pos pos_data[] = { n_past };
+    int32_t n_seq_id_data[] = { 1 };
+    llama_seq_id seq_id_data_internal[] = { 0 };
+    llama_seq_id* seq_id_data[] = {seq_id_data_internal};
+    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
 
-    GGML_ASSERT(model != nullptr);
-    GGML_ASSERT(last_embd != nullptr);
-    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
+    llama_batch batch = {
+        /*.n_tokens = */ 1,
+        /*.token    = */ token_data,
+        /*.embd     = */ nullptr,
+        /*.pos      = */ pos_data,
+        /*.n_seq_id = */ n_seq_id_data,
+        /*.seq_id   = */ seq_id_data,
+        /*.logits   = */ logits_data
+    };
 
-    common_sampler_sample(smpl, ctx, last_tok_idx, true);
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+    //LOG_INF("updating kv cache for n_past: %d\n", n_past);
 
-    const auto* cur_p = common_sampler_get_candidates(smpl);
-    /*LOG_INF("cur_p->size: %d\n", cur_p->size);
+    if (!smpl) {
+        return -1;
+    }
+    else {
+        common_sampler_sample(smpl, ctx, last_tok_idx, true);
+        const auto* cur_p = common_sampler_get_candidates(smpl);
+
+        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
+        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        //}
+
+        const llama_token id = cur_p->data[0].id;
+        return id;
+    }
+    // LOG_INF("cur_p->size: %d\n", cur_p->size);
 
-    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-    }*/
 
     // add drafted token for each sequence
-    const llama_token id = cur_p->data[0].id;
 
     // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
     // smpl will accept the token if it doesn't get rejected by main model later
@@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft(
     //result.reserve(1);
     //result.push_back(id);
     //return result;
-    return id;
+}
+
+
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
+    mtp_kv_update_data token;
+    for (int i = 0; i < tokens.size(); ++i) {
+        token = tokens[i];
+        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
+    }
+
+    tokens.clear();
 }
diff --git a/common/speculative.h b/common/speculative.h
index 6ff9e822f8..786f3ad1e8 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -12,6 +12,12 @@ struct common_speculative_params {
     float p_min = 0.75f; // min probability required to accept a token in the draft
 };
 
+struct mtp_kv_update_data {
+    llama_token id;
+    int32_t     n_past;
+    int32_t     tok_idx;
+};
+
 struct common_speculative * common_speculative_init(
         struct llama_context * ctx_tgt,
         struct llama_context * ctx_dft
@@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft(
         struct common_speculative_params params,
         const llama_tokens & prompt,
         llama_token id_last);
+
+void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
diff --git a/include/llama.h b/include/llama.h
index 16dc10d403..1de8a963cc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -544,9 +544,6 @@ extern "C" {
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
-    LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
-        struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
-
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -999,8 +996,6 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
-    LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
-
     //
     // Vocab
     //
@@ -1459,16 +1454,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
-
-    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
-
     LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
-
-    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
-
-
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 8698d89ace..ff73429301 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
                 }
             }
 
-            if (!ok) {
+            // TEMPORARILY DISABLING THIS SANITY CHECK
+            // TODO: UNDO THIS IF IT WORKS
+            /*if (!ok) {
                 LLAMA_LOG_ERROR(
                         "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
                         " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
                         __func__, s, s, p0, s, seq_pos_min(s));
 
                 return false;
-            }
+            }*/
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ca713fa389..34d514387b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params(
 }
 
 llm_graph_params llama_context::mtp_graph_params(
-    llm_graph_result* res,
-    const llama_ubatch& ubatch) {
+    llm_graph_result * res,
+    const llama_ubatch& ubatch,
+    const llama_memory_context_i * mctx) {
     size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
     ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
     return {
@@ -1462,7 +1463,7 @@ llm_graph_params llama_context::mtp_graph_params(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
-        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.mctx        =*/ mctx,
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ 1,
         /*.cb          =*/ graph_get_cb(temp_sched),
@@ -1470,6 +1471,21 @@ llm_graph_params llama_context::mtp_graph_params(
     };
 }
 
+std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
+    const auto& vocab   = model.vocab;
+    const auto& hparams = model.hparams;
+
+    const int64_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
+
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return nullptr;
+    }
+
+    return memory->init_batch(*balloc, 1, false);
+}
+
 ggml_status llama_context::graph_compute(
             ggml_cgraph * gf,
                    bool   batched) {
@@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
-ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
-    ctx->synchronize();
-
-    return ctx->get_embeddings_tensor();
-}
-
-
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
@@ -2985,42 +2994,43 @@ void llama_opt_epoch(
             callback_eval);
 }
 
-llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
-    return ctx->mtp_graph_params(res, ubatch);
-}
-
-
-ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
-    return ctx->graph_compute(gf, batched);
-}
-
 void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-    ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+    const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
 
     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
 
+    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
+    const auto& ubatch_mtp = mctx->get_ubatch();
 
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past;
-
-    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
-
-    auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
+    //llama_ubatch ubatch_mtp;
+    //ubatch_mtp.n_tokens = 1;
+    //ubatch_mtp.pos = &n_past;
 
+    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
 
     ggml_backend_sched_t sched = params_mtp->sched;
 
+    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
+
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+    }
+
+    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
+
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
-
     ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+
+    ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
+    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
+
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
     struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
 
-    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
+    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
 
     if (logits_mtp) {
         ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
diff --git a/src/llama-context.h b/src/llama-context.h
index 20314304c0..e8ea3a4c9b 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -200,12 +200,14 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
 
     void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
 
     ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
 
+    std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
+
 private:
     llm_graph_params graph_params(
             llm_graph_result * res,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b5184e4559..053c72d6dc 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
 
     return relative_bucket;
 }
-
-ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
-    return res->get_logits();
-}
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index e539142e6b..ed6cf969d4 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
 
     if (model.arch == LLM_ARCH_GLM4_MOE) {
         // GLM-4.5: Only process up to last layer, skip final NextN layer
-        n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
+        n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers;
     }
 
     // create a context for each buffer type
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b0c096dec6..04743e01f3 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
 struct llm_build_glm4_moe_mtp : public llm_graph_context {
     llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
         // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
+        llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         // ggml_set_i32(inp_pos, n_past);
         ggml_tensor * inp_pos = build_inp_pos();
 
-        llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
+        //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
+        auto * inp_attn = build_attn_inp_kv_unified();
 
         ggml_tensor * cur;
@@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
 
-        ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
-        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
-        ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);
+        ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
+        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
+        ggml_set_input(prev_embedding_leaf);
 
         // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
         ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
@@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 }
 
 ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
-    ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
+    llama_token last_token_id, int n_past) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
+                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -19024,10 +19025,3 @@ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_te
     return model->tensors_by_name;
 }
 
-ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
-    ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
-
-    return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
-}
-
-
diff --git a/src/llama-model.h b/src/llama-model.h
index 77a18aca71..b28a37488f 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -476,7 +476,7 @@ struct llama_model {
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
     ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
-        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
+        llama_token last_token_id, int n_past) const;
 
 private:
     struct impl;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index e5039fe86a..b85fa4e769 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1295,8 +1296,9 @@ struct server_slot {
     common_speculative * spec = nullptr;
 
     bool has_mtp = false;
+    std::vector<mtp_kv_update_data> mtp_kv_update_batch;
     int32_t last_tok_idx = -1;
-
+
     std::vector<common_adapter_lora_info> lora;
 
     // the index relative to completion multi-task request
@@ -1393,7 +1395,7 @@ struct server_slot {
     }
 
     bool need_embd() const {
-        return server_task_type_need_embd(task_type);
+        return server_task_type_need_embd(task_type) || has_mtp;
     }
 
     bool need_logits() const {
@@ -1569,6 +1571,7 @@ struct server_slot {
     }
 };
 
+
 struct server_metrics {
     int64_t t_start = 0;
 
@@ -1994,7 +1997,7 @@ struct server_context {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
             return false;
         }
-
+
         vocab = llama_model_get_vocab(model);
 
         n_ctx = llama_n_ctx(ctx);
@@ -2124,18 +2127,21 @@ struct server_context {
                         common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
                     }
                 }
-
+
                 // if model has MTP and no draft model is specified...
                 else if (llama_model_n_nextn_layer(model) > 0) {
                     SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
                     slot.has_mtp = true;
-
+
                     // assume one speculative token (true of all well-known MTP models so far)
                     slot.batch_spec = llama_batch_init(2, 0, 1);
                     SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
 
                     params_base.speculative.n_min = 0;
                     params_base.speculative.n_max = 1;
+
+                    SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
+                    llama_set_embeddings(ctx, true);
                 }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -3383,7 +3389,11 @@ struct server_context {
                     // embedding requires all tokens in the batch to be output
                     const bool need_embd = server_task_type_need_embd(slot.task_type);
 
+                    if (slot.has_mtp) {
+                        slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens });
+                    }
                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
+
                     slot.cache_tokens.push_back(cur_tok);
 
                     slot.n_prompt_tokens_processed++;
@@ -3533,6 +3543,11 @@ struct server_context {
 
                 const int tok_idx = slot.i_batch - i;
 
+                // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
+                if (slot.has_mtp) {
+                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+                }
+
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
                 slot.last_tok_idx = tok_idx;
@@ -3571,8 +3586,6 @@ struct server_context {
             }
         }
 
-        SRV_DBG("starting speculative decoding: %d\n", 1);
-
         // do speculative decoding
         for (auto & slot : slots) {
             if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3631,13 +3644,9 @@ struct server_context {
             //draft.reserve(1);
             //draft.push_back(draft_id);
 
-            for (const auto& str : draft) {
-                SLT_DBG(slot, "%s\n", str);
-            }
-
             // ignore small drafts
-            if (slot.params.speculative.n_min > (int) draft.size()) {
-                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+            if (slot.params.speculative.n_min > (int)draft.size()) {
+                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
                 continue;
             }
@@ -3661,8 +3670,12 @@ struct server_context {
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
 
-            // if slot has mtp
-            // call
+            if (slot.has_mtp) {
+                for (int32_t i = 0; i < ids.size(); ++i) {
+                    slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
+                }
+                mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
+            }
 
             slot.n_past += ids.size();
             slot.n_decoded += ids.size();