llama_mtp_start

This commit is contained in:
Xuan Son Nguyen 2026-01-17 20:54:09 +01:00
parent b4457e48bb
commit de17303dc6
4 changed files with 42 additions and 24 deletions

View File

@ -370,6 +370,7 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
bool is_mtp; // create context for Multi-Token Prediction (MTP)
// [EXPERIMENTAL]
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
@ -939,14 +940,6 @@ extern "C" {
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens using MTP (Multi-Token Prediction).
// The input token can be either from the last llama_decode() call,
// or from the previous llama_decode_mtp() call.
// Input token order must match the output token order from the previous call.
LLAMA_API int32_t llama_decode_mtp(
struct llama_context * ctx,
struct llama_batch batch);
// Set the number of threads used for decoding
// n_threads is the number of threads used for generation (single token)
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
@ -1014,6 +1007,10 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
// Copy the internal MTP state from ctx_llm to ctx_mtp, ready for MTP decoding.
// This must be done before calling llama_decode() on ctx_mtp
LLAMA_API int32_t llama_mtp_start(struct llama_context * ctx_llm, struct llama_context * ctx_mtp);
//
// backend sampling API [EXPERIMENTAL]
// note: use only if the llama_context was created with at least one llama_sampler_seq_config

View File

@ -27,6 +27,8 @@ llama_context::llama_context(
// may need to be backend-dependent
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
is_mtp = params.is_mtp;
t_start_us = model.t_start_us;
t_load_us = model.t_load_us;
@ -814,6 +816,23 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
if (!ctx_mtp.is_mtp) {
LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
return -1;
}
if (cross.n_token == 0 || cross.n_embd == 0) {
LLAMA_LOG_ERROR("%s: no state to copy\n", __func__);
return -1;
}
// TODO: maybe std::move is better?
ctx_mtp.cross = cross;
return 0;
}
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
@ -1456,7 +1475,7 @@ static void copy_tensor_async_candidates(
}
}
int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (is_mtp) {
@ -1464,7 +1483,7 @@ int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
return -1;
}
if (batch_inp.n_tokens > n_ubatch()) {
if ((uint32_t)batch_inp.n_tokens > n_ubatch()) {
// TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
return -1;
@ -3043,6 +3062,7 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
/*.is_mtp =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
};
@ -3233,6 +3253,12 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
ctx_llm->synchronize();
return ctx_llm->cpy_mtp_state(*ctx_mtp);
}
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
return ctx->set_sampler(seq_id, smpl);
}
@ -3553,7 +3579,7 @@ int32_t llama_encode(
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch, false);
const int ret = ctx->decode(batch);
if (ret != 0 && ret != 1) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
@ -3561,17 +3587,6 @@ int32_t llama_decode(
return ret;
}
int32_t llama_decode_mtp(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch, true);
if (ret != 0 && ret != 1) {
LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
}
return ret;
}
//
// perf
//

View File

@ -77,6 +77,8 @@ struct llama_context {
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
int32_t cpy_mtp_state(llama_context & ctx_mtp);
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@ -131,7 +133,7 @@ struct llama_context {
ggml_status & ret);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp, bool is_mtp);
int decode(const llama_batch & batch_inp);
//
// state save/load
@ -349,6 +351,8 @@ private:
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;
bool is_mtp = false;
bool has_evaluated_once = false;
// env: LLAMA_GRAPH_REUSE_DISABLE

View File

@ -54,7 +54,6 @@ llm_build_glm4_moe<false>::llm_build_glm4_moe(const llama_model & model, const l
template <>
llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
const bool use_mrope = hparams.use_mrope();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@ -73,6 +72,9 @@ llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const ll
? mtp_layer.nextn.embed_tokens : model.tok_embd);
ggml_tensor * inp_state_embd = build_inp_cross_mtp();
// check number of input tokens
GGML_ASSERT(inp_state_embd->ne[1] == inp_token_embd->ne[1]);
inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);