added proper KV cache management for MTP layers and slightly refactored

Aaron Lee 2025-08-17 04:59:36 -04:00
parent 6e9bafc7a7
commit 6870f9790c
11 changed files with 137 additions and 95 deletions

View File

@@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft(
int32_t n_past,
int32_t last_tok_idx) {
const auto * model = llama_get_model(ctx);
auto * last_embd = llama_get_embeddings_tensor(ctx);
llama_token token_data[] = { id_last };
llama_pos pos_data[] = { n_past };
int32_t n_seq_id_data[] = { 1 };
llama_seq_id seq_id_data_internal[] = { 0 };
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
GGML_ASSERT(model != nullptr);
GGML_ASSERT(last_embd != nullptr);
llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
llama_batch batch = {
/*.n_tokens = */ 1,
/*.token = */ token_data,
/*.embd = */ nullptr,
/*.pos = */ pos_data,
/*.n_seq_id = */ n_seq_id_data,
/*.seq_id = */ seq_id_data,
/*.logits = */ logits_data
};
common_sampler_sample(smpl, ctx, last_tok_idx, true);
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
//LOG_INF("updating kv cache for n_past: %d\n", n_past);
const auto* cur_p = common_sampler_get_candidates(smpl);
/*LOG_INF("cur_p->size: %d\n", cur_p->size);
if (!smpl) {
return -1;
}
else {
common_sampler_sample(smpl, ctx, last_tok_idx, true);
const auto* cur_p = common_sampler_get_candidates(smpl);
//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
//}
const llama_token id = cur_p->data[0].id;
return id;
}
// LOG_INF("cur_p->size: %d\n", cur_p->size);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
}*/
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
// smpl will accept the token if it doesn't get rejected by main model later
@@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft(
//result.reserve(1);
//result.push_back(id);
//return result;
return id;
}
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
for (const mtp_kv_update_data & token : tokens) {
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
}
tokens.clear();
}
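For reference, a short usage sketch of the new helper (hypothetical caller, mirroring the server changes later in this commit): tokens already decoded by the main model are queued together with their position and their row index in the decode batch, then flushed so the MTP layer's KV cache catches up in a single pass.

// Sketch only -- assumes the declarations added to common/speculative.h in this commit;
// `pending` lives in the caller (the server keeps one such vector per slot).
static void mtp_queue_and_flush(llama_context * ctx,
                                std::vector<mtp_kv_update_data> & pending,
                                const std::vector<llama_token> & accepted,
                                int32_t n_past) {
    for (int32_t i = 0; i < (int32_t) accepted.size(); ++i) {
        // { token, its position, index of its logits/embeddings row in the last decode batch }
        pending.push_back({ accepted[i], n_past + i, i });
    }
    // replays the queued tokens through the MTP head (null sampler = cache-only) and clears `pending`
    mtp_update_kv_cache(ctx, pending);
}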

View File

@@ -12,6 +12,12 @@ struct common_speculative_params {
float p_min = 0.75f; // min probability required to accept a token in the draft
};
struct mtp_kv_update_data {
llama_token id;
int32_t n_past;
int32_t tok_idx;
};
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
@@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft(
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);

View File

@@ -544,9 +544,6 @@ extern "C" {
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
@@ -999,8 +996,6 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
//
// Vocab
//
@@ -1459,16 +1454,8 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
#ifdef __cplusplus
}

View File

@@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
}
}
if (!ok) {
// Temporarily disabling this sanity check
// TODO: re-enable it once the MTP changes are confirmed to work
/*if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
@@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}*/
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {

View File

@@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params(
}
llm_graph_params llama_context::mtp_graph_params(
llm_graph_result* res,
const llama_ubatch& ubatch) {
llm_graph_result * res,
const llama_ubatch& ubatch,
const llama_memory_context_i * mctx) {
size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
return {
@@ -1462,7 +1463,7 @@ llm_graph_params llama_context::mtp_graph_params(
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mctx =*/ memory->init_batch(*balloc, 1, false).get(),
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ 1,
/*.cb =*/ graph_get_cb(temp_sched),
@@ -1470,6 +1471,21 @@ llm_graph_params llama_context::mtp_graph_params(
};
}
std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
const auto& vocab = model.vocab;
const auto& hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return nullptr;
}
return memory->init_batch(*balloc, 1, false);
}
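For orientation, a condensed sketch of how mtp_memory_batch() is meant to be consumed (it mirrors llama_build_and_execute_mtp_graph further down; the failure check before get_ubatch() is an extra precaution added in this sketch, not present in the hunk below):

// condensed sketch, names from this commit; `batch` is the 1-token llama_batch and `res` an llm_graph_result owned by the caller
llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch);    // reserve KV cells for the 1-token ubatch
if (!mctx || !mctx->apply()) {                                   // apply() commits the reservation to the KV cache
    LLAMA_LOG_ERROR("%s: failed to prepare MTP memory context\n", __func__);
    return;
}
const llama_ubatch & ubatch = mctx->get_ubatch();                // positions/seq ids used to build the MTP graph
llm_graph_params params = ctx->mtp_graph_params(res, ubatch, mctx.get());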
ggml_status llama_context::graph_compute(
ggml_cgraph * gf,
bool batched) {
@@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
ctx->synchronize();
return ctx->get_embeddings_tensor();
}
// llama adapter API
int32_t llama_set_adapter_lora(
@@ -2985,42 +2994,43 @@ void llama_opt_epoch(
callback_eval);
}
llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
return ctx->mtp_graph_params(res, ubatch);
}
ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
return ctx->graph_compute(gf, batched);
}
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
const auto * model = llama_get_model(ctx);
auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
const auto& ubatch_mtp = mctx->get_ubatch();
llama_ubatch ubatch_mtp;
ubatch_mtp.n_tokens = 1;
ubatch_mtp.pos = &n_past;
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
//llama_ubatch ubatch_mtp;
//ubatch_mtp.n_tokens = 1;
//ubatch_mtp.pos = &n_past;
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
ggml_backend_sched_t sched = params_mtp->sched;
auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
}
auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
ggml_backend_sched_graph_compute(sched, gf); // execute the graph
struct ggml_tensor * logits_mtp = res_mtp->get_logits();
LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
if (logits_mtp) {
ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);

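The execution above follows a general ggml pattern: allocate the graph through the scheduler first, then fill input tensors looked up by name, then compute. Below is a self-contained toy illustration of the same allocate-then-fill-by-name flow on the CPU backend (generic ggml usage, not code from this commit; the ggml-cpu.h include assumes a recent ggml tree).

#include "ggml.h"
#include "ggml-alloc.h"     // ggml_backend_alloc_ctx_tensors
#include "ggml-backend.h"
#include "ggml-cpu.h"       // ggml_backend_cpu_init (declared in ggml-backend.h on older trees)
#include <cstdio>
#include <vector>

int main() {
    // describe the graph with no_alloc = true; tensor data is allocated by the backend later
    ggml_init_params ip = { /*mem_size  =*/ ggml_tensor_overhead() * 16 + ggml_graph_overhead(),
                            /*mem_buffer=*/ nullptr,
                            /*no_alloc  =*/ true };
    ggml_context * ctx = ggml_init(ip);

    ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_name(x, "x");      // analogous to "mtp_prev_embedding_input"
    ggml_set_input(x);

    ggml_tensor * y = ggml_scale(ctx, x, 2.0f);
    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);   // allocate all tensors

    // fill the input by name only after allocation, exactly as the MTP path does
    std::vector<float> data = { 1, 2, 3, 4 };
    ggml_backend_tensor_set(ggml_get_tensor(ctx, "x"), data.data(), 0, data.size() * sizeof(float));

    ggml_backend_graph_compute(backend, gf);

    float out[4];
    ggml_backend_tensor_get(y, out, 0, sizeof(out));
    printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);   // 2.0 4.0 6.0 8.0

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}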
View File

@@ -200,12 +200,14 @@ public:
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
private:
llm_graph_params graph_params(
llm_graph_result * res,

View File

@@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
return relative_bucket;
}
ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
return res->get_logits();
}

View File

@@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
}
if (model.arch == LLM_ARCH_GLM4_MOE) {
// GLM-4.5: cache all layers, including the final NextN (MTP) layer, so the MTP head has its own KV entries
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
n_layer_cache = hparams.n_layer; // previously skipped the final NextN layer
}
// create a context for each buffer type

View File

@@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
struct llm_build_glm4_moe_mtp : public llm_graph_context {
llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
// For v0, rebuild the computational graph on every step; this mimics the vLLM implementation's parameterization
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
llama_token last_token_id, int n_past
) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
// ggml_set_i32(inp_pos, n_past);
ggml_tensor * inp_pos = build_inp_pos();
llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
//llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
auto * inp_attn = build_attn_inp_kv_unified();
ggml_tensor * cur;
@@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);
ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
ggml_set_input(prev_embedding_leaf);
// vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
@@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
}
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
llama_token last_token_id, int n_past) const {
std::unique_ptr<llm_graph_context> llm;
switch (arch) {
case LLM_ARCH_GLM4_MOE:
{
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
} break;
default:
GGML_ABORT("fatal error");
@@ -19024,10 +19025,3 @@ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_te
return model->tensors_by_name;
}
ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
}

View File

@@ -476,7 +476,7 @@ struct llama_model {
// TODO: move this to new llm_arch_model_i interface
ggml_cgraph * build_graph(const llm_graph_params & params) const;
ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
llama_token last_token_id, int n_past) const;
private:
struct impl;

View File

@@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result {
}
};
struct server_slot {
int id;
int id_task = -1;
@@ -1295,8 +1296,9 @@ struct server_slot {
common_speculative * spec = nullptr;
bool has_mtp = false;
std::vector<mtp_kv_update_data> mtp_kv_update_batch;
int32_t last_tok_idx = -1;
std::vector<common_adapter_lora_info> lora;
// the index relative to completion multi-task request
@@ -1393,7 +1395,7 @@ struct server_slot {
}
bool need_embd() const {
return server_task_type_need_embd(task_type);
return server_task_type_need_embd(task_type) || has_mtp;
}
bool need_logits() const {
@@ -1569,6 +1571,7 @@ struct server_slot {
}
};
struct server_metrics {
int64_t t_start = 0;
@@ -1994,7 +1997,7 @@ struct server_context {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
return false;
}
vocab = llama_model_get_vocab(model);
n_ctx = llama_n_ctx(ctx);
@@ -2124,18 +2127,21 @@ struct server_context {
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
}
}
// if model has MTP and no draft model is specified...
else if (llama_model_n_nextn_layer(model) > 0) {
SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
slot.has_mtp = true;
// assume one speculative token (true of all well-known MTP models so far)
slot.batch_spec = llama_batch_init(2, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
params_base.speculative.n_min = 0;
params_base.speculative.n_max = 1;
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
llama_set_embeddings(ctx, true);
}
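For context, drafting and cache maintenance share the same MTP entry point; a hedged sketch of the two call shapes (the drafting call site itself is not part of this diff, the argument order is inferred from the definition in common/speculative.cpp above, and id_last is a placeholder for the last sampled token):

// draft one token: the slot's sampler is passed, so the graph requests logits (smpl != nullptr)
llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id_last, slot.n_past, slot.last_tok_idx);

// cache-only update: the null sampler inside mtp_update_kv_cache disables logits; only the MTP KV cache advances
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);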
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
@@ -3383,7 +3389,11 @@ struct server_context {
// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);
if (slot.has_mtp) {
slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens });
}
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
slot.cache_tokens.push_back(cur_tok);
slot.n_prompt_tokens_processed++;
@@ -3533,6 +3543,11 @@ struct server_context {
const int tok_idx = slot.i_batch - i;
// The update batch is only non-empty once, right after prompt processing; during normal token generation this call is a no-op
if (slot.has_mtp) {
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
}
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
slot.last_tok_idx = tok_idx;
@@ -3571,8 +3586,6 @@ struct server_context {
}
}
SRV_DBG("starting speculative decoding: %d\n", 1);
// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3631,13 +3644,9 @@ struct server_context {
//draft.reserve(1);
//draft.push_back(draft_id);
for (const auto& str : draft) {
SLT_DBG(slot, "%s\n", str);
}
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
if (slot.params.speculative.n_min > (int)draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
continue;
}
@@ -3661,8 +3670,12 @@ struct server_context {
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
// for MTP slots, replay the accepted tokens through the MTP head so its KV cache stays in sync with the main model
if (slot.has_mtp) {
for (int32_t i = 0; i < (int32_t) ids.size(); ++i) {
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
}
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
}
slot.n_past += ids.size();
slot.n_decoded += ids.size();