broad thrust of the mtp implementation

This commit is contained in:
Aaron Lee 2025-08-13 02:21:17 -04:00
parent 03231da69e
commit cf0f7c0448
9 changed files with 260 additions and 11 deletions

View File

@ -5,6 +5,7 @@
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "../src/llama-graph.h"
#include <cstring>
#include <algorithm>
@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}
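// generate a single draft token using the model's MTP (multi-token prediction) head,
// based on the last accepted token and the hidden state left in the context by the
// previous decode call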
llama_tokens mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx) {
llama_tokens result;
LOG_INF("step: '%d'\n", 1);
// sample one token from the draft model -- this does NOT generalize to >1 MTP head
result.reserve(1);
// need to determine which architecture we're using so we call the correct MTP model
const auto * model = llama_get_model(ctx);
LOG_INF("step: '%d'\n", 2);
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
//auto * gf = model.build_graph(gparams);
LOG_INF("step: '%d'\n", 3);
/*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
ret = GGML_STATUS_ALLOC_FAILED;
return nullptr;
}*/
//llm_graph_result res_mtp(ctx->graph_max_nodes());
// TODO: res_mtp is never allocated; it must point to a valid llm_graph_result
// before the graph is built and its logits are read
llm_graph_result * res_mtp = nullptr;
llama_ubatch ubatch_mtp = {};
ubatch_mtp.n_tokens = 1;
ubatch_mtp.pos = &n_past; // critical for positional encoding (RoPE)
// We also need a minimal ubatch to provide positional context (RoPE)
// ubatch_mtp.tokens = &last_token_id;
// ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
// ubatch_mtp.logits = nullptr;
// ubatch_mtp.all_pos_0 = -1;
// ubatch_mtp.all_pos_1 = -1;
// ubatch_mtp.all_seq_id = -1;
// Manually construct the graph parameters
//const llm_graph_params params_mtp = {
// /*.arch =*/ model->arch,
// /*.hparams =*/ model->hparams,
// /*.cparams =*/ ctx->cparams,
// /*.ubatch =*/ ubatch_mtp,
// /*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
// /*.sched =*/ ctx->sched.get(),
// /*.backend_cpu =*/ ctx->backend_cpu,
// /*.cvec =*/ &ctx->cvec,
// /*.loras =*/ &ctx->loras,
// /*.mctx =*/ llama_get_memory(ctx), // Use the KV cache's memory context
// /*.cross =*/ &ctx->cross,
// /*.n_outputs =*/ 1,
// /*.cb =*/ ctx->graph_get_cb(),
// /*.res =*/ &res_mtp, // Point to our temporary result object
//};
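// the parameters above cannot be assembled here because llama_context's internals are
// private to the library, so llama_mtp_graph_params() builds them in-tree instead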
llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
LOG_INF("step: '%d'\n", 4);
// ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
// ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
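// the MTP head consumes the hidden state of the last decoded token, which the main
// decode pass left in the context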
auto * last_embd = llama_get_embeddings_tensor(ctx);
LOG_INF("step: '%d'\n", 5);
GGML_ASSERT(model != nullptr);
GGML_ASSERT(last_embd != nullptr);
auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
if (!gf) {
LOG_INF("%s: failed to initialize graph\n", __func__);
//ret = GGML_STATUS_FAILED;
return result;
}
LOG_INF("step: '%d'\n", 6);
const auto status = llama_graph_compute(ctx, gf, false);
if (status != GGML_STATUS_SUCCESS) {
LOG_ERR("%s: failed to compute the MTP graph\n", __func__);
return result;
}
LOG_INF("step: '%d'\n", 7);
struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
LOG_INF("step: '%d'\n", 8);
if (logits_mtp) {
llama_set_logits(ctx, logits_mtp);
}
LOG_INF("step: '%d'\n", 9);
{
common_sampler_sample(smpl, ctx, last_tok_idx, true);
LOG_INF("step: '%d'\n", 10);
const auto * cur_p = common_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}
// add the drafted token to the result
const llama_token id = cur_p->data[0].id;
// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by main model later
// common_sampler_accept(smpl, id, true);
result.push_back(id);
}
return result;
}

View File

@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);
// sample a single draft token using the model's MTP head
llama_tokens mtp_speculative_gen_draft(
struct common_sampler * smpl,
struct llama_context * ctx,
llama_token id_last,
int32_t n_past,
int32_t last_tok_idx);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,

View File

@ -544,12 +544,17 @@ extern "C" {
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
const char * fname_out,
const llama_model_quantize_params * params);
//
// Adapters
//
@ -972,6 +977,8 @@ extern "C" {
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
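// Override the context's internal logits buffer with the contents of the given tensor
// (used by the MTP draft path to expose draft-head logits to the samplers)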
LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override);
// Get all output token embeddings.
// when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
// the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@ -994,6 +1001,8 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
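// Get the output embeddings tensor produced by the last decode call
// (consumed by the MTP draft pass as its input hidden state)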
LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
//
// Vocab
//
@ -1452,6 +1461,14 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
#ifdef __cplusplus
}
#endif

View File

@ -6,6 +6,7 @@
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-graph.h"
#include <cinttypes>
#include <cstring>
@ -522,6 +523,14 @@ float * llama_context::get_logits() {
return logits;
}
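// copy logits produced by an externally built graph (e.g. the MTP head) into this
// context's output buffer so the regular sampling path can consume them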
void llama_context::set_logits(struct ggml_tensor * logit_override) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
}
float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;
@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
ggml_tensor * llama_context::get_embeddings_tensor() {
return embd_tensor;
}
void llama_context::attach_threadpool(
ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch) {
@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
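// keep a handle to this decode's embeddings tensor so a later MTP draft pass can
// reuse the hidden state of the last decoded token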
embd_tensor = res->get_embd();
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@ -1429,6 +1443,27 @@ llm_graph_params llama_context::graph_params(
};
}
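// graph parameters for an MTP draft pass: a decoder-type graph over a single-token
// ubatch with exactly one output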
llm_graph_params llama_context::mtp_graph_params(
llm_graph_result* res,
const llama_ubatch& ubatch) const {
return {
/*.arch =*/ model.arch,
/*.hparams =*/ model.hparams,
/*.cparams =*/ cparams,
/*.ubatch =*/ ubatch,
/*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ memory->init_batch(*balloc, 1, false).get(), // memory (KV cache) context for the single-token draft ubatch
/*.cross =*/ &cross,
/*.n_outputs =*/ 1,
/*.cb =*/ graph_get_cb(),
/*.res =*/ res,
};
}
ggml_status llama_context::graph_compute(
ggml_cgraph * gf,
bool batched) {
@ -2233,6 +2268,7 @@ void llama_context::opt_epoch(
llama_batch_free(batch);
}
//
// interface implementation
//
@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() {
return result;
}
llama_context * llama_init_from_model(
llama_model * model,
llama_context_params params) {
@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
return ctx->get_logits_ith(i);
}
void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) {
ctx->set_logits(logit_override);
}
float * llama_get_embeddings(llama_context * ctx) {
ctx->synchronize();
@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
ctx->synchronize();
return ctx->get_embeddings_tensor();
}
// llama adapter API
int32_t llama_set_adapter_lora(
@ -2926,3 +2976,12 @@ void llama_opt_epoch(
callback_train,
callback_eval);
}
llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
return ctx->mtp_graph_params(res, ubatch);
}
ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
return ctx->graph_compute(gf, batched);
}

View File

@ -59,6 +59,7 @@ struct llama_context {
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
ggml_tensor * get_embeddings_tensor();
void attach_threadpool(
ggml_threadpool_t threadpool,
@ -199,6 +200,10 @@ public:
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
void set_logits(struct ggml_tensor* logit_override);
private:
llm_graph_params graph_params(
llm_graph_result * res,
@ -240,6 +245,7 @@ private:
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
size_t embd_size = 0; // capacity (of floats) for embeddings
float * embd = nullptr;
ggml_tensor * embd_tensor = nullptr;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@ -308,3 +314,4 @@ private:
mutable int32_t n_reused = 0; // number of times the previous graph was reused
};

View File

@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
return relative_bucket;
}
ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
return res->get_logits();
}

View File

@ -818,3 +818,4 @@ struct llm_graph_context {
// TODO: better name
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);

View File

@ -18673,19 +18673,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
return llm->res->get_gf();
}
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params & params,
ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
std::unique_ptr<llm_graph_context> llm;
switch (arch) {
case LLM_ARCH_GLM4_MOE:
{
printf("step: '%d'\n", 56);
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
} break;
default:
GGML_ABORT("fatal error");
}
printf("step: '%d'\n", 57);
return llm->res->get_gf();
}
@ -19004,3 +19006,11 @@ bool llama_model_is_diffusion(const llama_model * model) {
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
return model->tensors_by_name;
}
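// C API wrapper that dispatches to the architecture-specific MTP graph builder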
ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
printf("step: '%d'\n", 55);
return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
}

View File

@ -1294,7 +1294,8 @@ struct server_slot {
mtmd_context * mctx = nullptr;
common_speculative * spec = nullptr;
bool has_mtp = false;
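// batch index of the most recently sampled token; the MTP draft samples at this index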
int32_t last_tok_idx = -1;
std::vector<common_adapter_lora_info> lora;
@ -1432,8 +1433,8 @@ struct server_slot {
}
bool can_speculate() const {
return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
// return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
}
void add_token(const completion_token_output & token) {
@ -1993,7 +1994,7 @@ struct server_context {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
return false;
}
vocab = llama_model_get_vocab(model);
n_ctx = llama_n_ctx(ctx);
@ -3531,6 +3532,7 @@ struct server_context {
const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
slot.last_tok_idx = tok_idx;
slot.i_batch = -1;
@ -3567,6 +3569,8 @@ struct server_context {
}
}
SRV_DBG("starting speculative decoding: %d\n", 1);
// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
@ -3583,7 +3587,9 @@ struct server_context {
}
// determine the max draft that fits the current slot state
SLT_DBG(slot, "starting mtp draft: %d\n", 2);
int n_draft_max = slot.params.speculative.n_max;
SLT_DBG(slot, "starting mtp draft: %d\n", 3);
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
@ -3601,15 +3607,25 @@ struct server_context {
continue;
}
SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp);
llama_token id = slot.sampled;
llama_tokens draft;
if (slot.has_mtp) {
SLT_DBG(slot, "starting mtp draft: %d\n", 1);
draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
} else {
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
}
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {