diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9c04d35fd0..a5824ebeed 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_sampler_apply(chain, &cur_p);
 
+    /*for (int k = 0; k < (int)cur_p.size; ++k) {
+        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
+            k, 0, cur_p.data[k].id, cur_p.data[k].p);
+    }*/
+
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
     const llama_token id = cur_p.data[cur_p.selected].id;
diff --git a/common/speculative.cpp b/common/speculative.cpp
index e46a0968bd..fa784f62f6 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -6,6 +6,7 @@
 #include "common.h"
 #include "sampling.h"
 #include "../src/llama-graph.h"
+#include "../src/llama-context.h"
 
 #include <cstring>
 #include <algorithm>
@@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft(
     }
 
-llama_tokens mtp_speculative_gen_draft(
-    struct common_sampler * smpl,
-    struct llama_context * ctx,
-    llama_token id_last,
-    int32_t n_past,
-    int32_t last_tok_idx) {
+llama_token mtp_speculative_gen_draft(
+    struct common_sampler* smpl,
+    struct llama_context* ctx,
+    llama_token id_last,
+    int32_t n_past,
+    int32_t last_tok_idx) {
-    llama_tokens result;
-
-    LOG_INF("step: '%d'\n", 1);
-
-    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
-    result.reserve(1);
-
-    // need to determine which architecture we're using so we call the correct MTP model
     const auto * model = llama_get_model(ctx);
-
-    LOG_INF("step: '%d'\n", 2);
-
-    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
-    //auto * gf = model.build_graph(gparams);
-
-    LOG_INF("step: '%d'\n", 3);
-
-    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
-    }*/
-
-    //llm_graph_result res_mtp(ctx->graph_max_nodes());
-    llm_graph_result * res_mtp;
-    llama_ubatch ubatch_mtp;
-    ubatch_mtp.n_tokens = 1;
-    ubatch_mtp.pos = &n_past; // Critical for positional encoding
-
-    // We also need a minimal ubatch to provide positional context (RoPE)
-    // ubatch_mtp.tokens = &last_token_id;
-    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
-    // ubatch_mtp.logits = nullptr;
-    // ubatch_mtp.all_pos_0 = -1;
-    // ubatch_mtp.all_pos_1 = -1;
-    // ubatch_mtp.all_seq_id = -1;
-
-    // Manually construct the graph parameters
-    //const llm_graph_params params_mtp = {
-    //    /*.arch        =*/ model->arch,
-    //    /*.hparams     =*/ model->hparams,
-    //    /*.cparams     =*/ ctx->cparams,
-    //    /*.ubatch      =*/ ubatch_mtp,
-    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
-    //    /*.sched       =*/ ctx->sched.get(),
-    //    /*.backend_cpu =*/ ctx->backend_cpu,
-    //    /*.cvec        =*/ &ctx->cvec,
-    //    /*.loras       =*/ &ctx->loras,
-    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
-    //    /*.cross       =*/ &ctx->cross,
-    //    /*.n_outputs   =*/ 1,
-    //    /*.cb          =*/ ctx->graph_get_cb(),
-    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
-    //};
-    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
-
-    LOG_INF("step: '%d'\n", 4);
-
-    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
-    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
     auto * last_embd = llama_get_embeddings_tensor(ctx);
 
-    LOG_INF("step: '%d'\n", 5);
-
     GGML_ASSERT(model != nullptr);
     GGML_ASSERT(last_embd != nullptr);
 
+    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
-    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
+    common_sampler_sample(smpl, ctx, last_tok_idx, true);
 
-    if (!gf) {
-        LOG_INF("%s: failed to initialize graph\n", __func__);
-        //ret = GGML_STATUS_FAILED;
-        return result;
-    }
+    const auto* cur_p = common_sampler_get_candidates(smpl);
+    /*LOG_INF("cur_p->size: %d\n", cur_p->size);
-    LOG_INF("step: '%d'\n", 6);
+    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    }*/
 
-    const auto status = llama_graph_compute(ctx, gf, false);
+    // add drafted token for each sequence
+    const llama_token id = cur_p->data[0].id;
 
-    LOG_INF("step: '%d'\n", 7);
+    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
+    // smpl will accept the token if it doesn't get rejected by main model later
+    // common_sampler_accept(smpl, id, true);
 
-    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
-    float * ctx_logit_pointer = llama_get_logits(ctx);
-
-    LOG_INF("step: '%d'\n", 8);
-
-    if (logits_mtp) {
-        llama_set_logits(ctx, logits_mtp);
-    }
-
-    LOG_INF("step: '%d'\n", 9);
-
-    {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-
-        LOG_INF("step: '%d'\n", 10);
-
-        const auto * cur_p = common_sampler_get_candidates(smpl);
-
-        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
-            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        }
-
-        // add drafted token for each sequence
-        const llama_token id = cur_p->data[0].id;
-
-        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-        // smpl will accept the token if it doesn't get rejected by main model later
-        // common_sampler_accept(smpl, id, true);
-
-        result.push_back(id);
-    }
-
-    return result;
+    //llama_tokens result;
+    //result.reserve(1);
+    //result.push_back(id);
+    //return result;
+    return id;
 }
diff --git a/common/speculative.h b/common/speculative.h
index 3b04890073..6ff9e822f8 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -29,7 +29,7 @@ void common_speculative_add_replacement_tgt_dft(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_tokens mtp_speculative_gen_draft(
+llama_token mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,
     int32_t n_past,
diff --git a/include/llama.h b/include/llama.h
index 2134f62d52..16dc10d403 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -977,8 +977,6 @@ extern "C" {
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
-    LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override);
-
     // Get all output token embeddings.
     // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
     // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@@ -1465,6 +1463,9 @@ extern "C" {
     LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
 
+    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+
     LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26c3e639d8..ca713fa389 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -523,12 +523,16 @@ float * llama_context::get_logits() {
     return logits;
 }
 
-void llama_context::set_logits(struct ggml_tensor * logit_override) {
-    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
+void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
+    output_reorder();
+
+    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
     GGML_ASSERT(backend_res != nullptr);
     GGML_ASSERT(logits != nullptr);
 
-    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
+    int64_t j = output_ids[i];
+
+    ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
@@ -1445,21 +1449,23 @@ llm_graph_params llama_context::graph_params(
 
 llm_graph_params llama_context::mtp_graph_params(
         llm_graph_result* res,
-        const llama_ubatch& ubatch) const {
+        const llama_ubatch& ubatch) {
+    size_t n_nodes = std::max(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
+    ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
     return {
         /*.arch        =*/ model.arch,
         /*.hparams     =*/ model.hparams,
         /*.cparams     =*/ cparams,
         /*.ubatch      =*/ ubatch,
         /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
-        /*.sched       =*/ sched.get(),
+        /*.sched       =*/ temp_sched,
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec        =*/ &cvec,
         /*.loras       =*/ &loras,
         /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
         /*.cross       =*/ &cross,
         /*.n_outputs   =*/ 1,
-        /*.cb          =*/ graph_get_cb(),
+        /*.cb          =*/ graph_get_cb(temp_sched),
         /*.res         =*/ res,
     };
 }
@@ -1491,8 +1497,10 @@ ggml_status llama_context::graph_compute(
     return status;
 }
 
-llm_graph_cb llama_context::graph_get_cb() const {
-    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
+llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const {
+    ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get();
+
+    return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
         if (il >= 0) {
             ggml_format_name(cur, "%s-%d", name, il);
         } else {
@@ -1502,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
         if (!cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+                ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu);
             }
         }
@@ -1515,7 +1523,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
             for (const auto & backend : backends) {
                 if (ggml_backend_get_device(backend.get()) == dev_layer) {
                     if (ggml_backend_supports_op(backend.get(), cur)) {
-                        ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                        ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get());
                     }
                 }
             }
@@ -1524,6 +1532,10 @@ llm_graph_cb llama_context::graph_get_cb() const {
     };
 }
 
+ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) {
+    return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload);
+}
+
 //
 // state save/load
 //
@@ -2450,10 +2462,6 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     return ctx->get_logits_ith(i);
 }
 
-void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) {
-    ctx->set_logits(logit_override);
-}
-
 float * llama_get_embeddings(llama_context * ctx) {
     ctx->synchronize();
@@ -2985,3 +2993,37 @@ llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* re
 ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
     return ctx->graph_compute(gf, batched);
 }
+
+void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+    ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
+
+    const auto * model = llama_get_model(ctx);
+
+    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
+
+    llama_ubatch ubatch_mtp;
+    ubatch_mtp.n_tokens = 1;
+    ubatch_mtp.pos = &n_past;
+
+    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
+
+    auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
+
+    ggml_backend_sched_t sched = params_mtp->sched;
+
+    ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+    ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
+
+    ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
+
+    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+    ggml_backend_sched_graph_compute(sched, gf); // execute the graph
+
+    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
+
+    if (logits_mtp) {
+        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
+    }
+}
+
diff --git a/src/llama-context.h b/src/llama-context.h
index 44bcdf6d95..20314304c0 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -200,9 +200,11 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
 
-    void set_logits(struct ggml_tensor* logit_override);
+    void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
+
+    ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
 
 private:
     llm_graph_params graph_params(
@@ -211,7 +213,7 @@ private:
             const llama_memory_context_i * mctx,
             llm_graph_type gtype) const;
 
-    llm_graph_cb graph_get_cb() const;
+    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8a9ba84803..b0c096dec6 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13950,7 +13950,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
     // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
         ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
-
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13958,22 +13957,43 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         const int il = hparams.n_layer - 1;
         const auto & mtp_layer = model.layers[il];
 
-        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
-        ggml_set_i32(inp_pos, n_past);
-        llm_graph_input_attn_no_cache * inp_attn = nullptr;
+        // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        // ggml_set_i32(inp_pos, n_past);
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
 
         ggml_tensor * cur;
 
         // get MTP embedding for last (conventionally sampled) token
+        // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
+        // LLAMA_LOG_INFO("step: '%d'\n", 5641);
+        // ggml_set_i32(inp_token_id, last_token_id);
+        //ggml_set_no_alloc(ctx0, false);
+        //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id);
+
         ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
-        ggml_set_i32(inp_token_id, last_token_id);
+        ggml_set_name(inp_token_id, "mtp_token_id_input");
+        ggml_set_input(inp_token_id);
+
+        //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
+        //ggml_set_no_alloc(ctx0, true);
+
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
 
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
 
+        ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
+        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
+        ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);
+
         // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
-        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+        ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
 
+        //token_emb_norm = ggml_cont(ctx0, token_emb_norm);
+        //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm);
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
+
+
         cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
@@ -14071,7 +14091,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
             cur = ggml_add(ctx0, routed_out, shared_out);
             cb(cur, "ffn_out", il);
         }
-
         cur = ggml_add(ctx0, cur, ffn_inp);
 
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
@@ -18680,14 +18699,12 @@ ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
-                printf("step: '%d'\n", 56);
                 llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
-    printf("step: '%d'\n", 57);
     return llm->res->get_gf();
 }
 
@@ -19009,8 +19026,8 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_te
 
 ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
-    printf("step: '%d'\n", 55);
     return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
 }
+
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 29d551ea51..e5039fe86a 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2132,6 +2132,8 @@ struct server_context {
                 // assume one speculative token (true of all well-known MTP models so far)
                 slot.batch_spec = llama_batch_init(2, 0, 1);
 
+                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
+
                 params_base.speculative.n_min = 0;
                 params_base.speculative.n_max = 1;
             }
@@ -3587,9 +3589,7 @@ struct server_context {
             }
 
             // determine the max draft that fits the current slot state
-            SLT_DBG(slot, "starting mtp draft: %d\n", 2);
            int n_draft_max = slot.params.speculative.n_max;
-            SLT_DBG(slot, "starting mtp draft: %d\n", 3);
 
             // note: n_past is not yet increased for the `id` token sampled above
             // also, need to leave space for 1 extra token to allow context shifts
@@ -3607,14 +3607,13 @@ struct server_context {
                 continue;
             }
 
-            SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp);
-
             llama_token id = slot.sampled;
 
             llama_tokens draft;
 
             if (slot.has_mtp) {
-                SLT_DBG(slot, "starting mtp draft: %d\n", 1);
-                llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                draft.reserve(1);
+                draft.push_back(draft_id);
             } else {
                 struct common_speculative_params params_spec;
@@ -3624,7 +3623,16 @@ struct server_context {
 
                 const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();
 
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+            }
+
+            //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+            //llama_tokens draft;
+            //draft.reserve(1);
+            //draft.push_back(draft_id);
+
+            for (const auto& tok : draft) {
+                SLT_DBG(slot, "draft token: %d\n", tok);
             }
 
             // ignore small drafts
@@ -3636,6 +3644,7 @@ struct server_context {
 
             // keep track of total number of drafted tokens tested
             slot.n_draft_total += draft.size();
+            SLT_DBG(slot, "draft size = %d\n", (int) draft.size());
 
             // construct the speculation batch
             common_batch_clear(slot.batch_spec);
@@ -3652,6 +3661,9 @@ struct server_context {
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+            // if slot has mtp
+            // call
+
             slot.n_past    += ids.size();
             slot.n_decoded += ids.size();
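Usage sketch (not part of the patch): the server-side flow above reduces to "draft one token with the MTP head, then verify it with the main model in a two-token batch". The helper below illustrates that flow with APIs this patch introduces or already relies on (mtp_speculative_gen_draft, common_batch_clear/common_batch_add, llama_decode, common_sampler_sample_and_accept_n). The function name, the single sequence id 0, and the caller-provided batch_spec are illustrative assumptions; the actual integration is the server.cpp hunk above.

    #include "common.h"
    #include "sampling.h"
    #include "speculative.h"
    #include "llama.h"

    // Draft a single MTP token for `id`, verify it with the main model, and return
    // the updated n_past. `batch_spec` is assumed to come from llama_batch_init(2, 0, 1).
    static int32_t mtp_draft_and_verify(
            struct common_sampler * smpl,
            struct llama_context  * ctx,
            llama_batch           & batch_spec,
            llama_token             id,           // token just sampled by the main model
            int32_t                 n_past,
            int32_t                 last_tok_idx) {
        // 1. draft exactly one token with the MTP head, reusing the hidden state the
        //    main model produced for `id` (exposed via llama_get_embeddings_tensor)
        const llama_token draft_id = mtp_speculative_gen_draft(smpl, ctx, id, n_past, last_tok_idx);

        // 2. verify: decode [id, draft_id] with the main model in one small batch
        common_batch_clear(batch_spec);
        common_batch_add(batch_spec, id,       n_past,     { 0 }, true);
        common_batch_add(batch_spec, draft_id, n_past + 1, { 0 }, true);
        llama_decode(ctx, batch_spec);

        // 3. accept the longest prefix of the draft that the main model agrees with:
        //    ids.size() is 1 if the draft was rejected, 2 if it was accepted
        const llama_tokens draft = { draft_id };
        const llama_tokens ids   = common_sampler_sample_and_accept_n(smpl, ctx, draft);

        return n_past + (int32_t) ids.size();
    }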