Xuan Son Nguyen 2026-01-16 23:38:54 +01:00
parent 3d4b6c7fd2
commit b4457e48bb
8 changed files with 63 additions and 7 deletions

View File

@@ -939,6 +939,14 @@ extern "C" {
struct llama_context * ctx,
struct llama_batch batch);
// Process a batch of tokens using MTP (Multi-Token Prediction).
// The input tokens can come either from the last llama_decode() call
// or from a previous llama_decode_mtp() call.
// The input token order must match the output token order of the previous call.
LLAMA_API int32_t llama_decode_mtp(
struct llama_context * ctx,
struct llama_batch batch);
// Set the number of threads used for decoding
// n_threads is the number of threads used for generation (single token)
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
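A minimal caller-side sketch of how the new entry point could be driven (not part of this commit; the token buffer, batch construction, and error handling are assumptions): a regular llama_decode() pass followed by an MTP pass over the same tokens, in the same order as the previous call's outputs.

// sketch only: `tokens` / `n_tokens` are assumed to be prepared by the caller
llama_batch batch = llama_batch_get_one(tokens, n_tokens);

if (llama_decode(ctx, batch) != 0) {
    // handle decode failure
}

// the MTP pass consumes the hidden state stashed by the previous call,
// so the token order here must match that call's output order
if (llama_decode_mtp(ctx, batch) != 0) {
    // handle MTP decode failure
}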

View File

@@ -1456,9 +1456,21 @@ static void copy_tensor_async_candidates(
}
}
int llama_context::decode(const llama_batch & batch_inp) {
int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (is_mtp) {
if (model.hparams.nextn_predict_layers == 0) {
LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
return -1;
}
if (batch_inp.n_tokens > n_ubatch()) {
// TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
return -1;
}
}
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp);
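Caller-side counterpart of the guard above, as a sketch (the splitting strategy and variable names are assumptions): since the MTP path rejects batches larger than one micro-batch, the batch size can be checked against llama_n_ubatch() up front.

// sketch: keep the MTP batch within a single ubatch, since the guard above
// returns -1 when n_tokens > n_ubatch and is_mtp is true
const uint32_t n_ubatch = llama_n_ubatch(ctx);
if ((uint32_t) batch.n_tokens > n_ubatch) {
    // split the batch, or create the context with a larger n_ubatch,
    // before calling llama_decode_mtp()
}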
@@ -1587,6 +1599,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
break;
}
const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0;
// set MTP state if needed
if (update_mtp_state) {
cross.n_embd = hparams.get_n_embd_mtp();
cross.n_token = n_outputs;
cross.mtp_embd.resize(cross.n_embd*cross.n_token);
}
// reserve output buffer
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
@@ -1615,7 +1636,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
llm_graph_type gtype = is_mtp ? LLM_GRAPH_TYPE_DECODER_MTP : LLM_GRAPH_TYPE_DECODER;
const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
@@ -1683,6 +1705,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
// set MTP state if needed
if (update_mtp_state) {
const int64_t n_embd_mtp = cross.n_embd;
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_mtp <= (int64_t)cross.mtp_embd.size());
ggml_backend_tensor_get_async(backend_embd, t_embd, cross.mtp_embd.data(), 0, n_outputs*n_embd_mtp*sizeof(float));
}
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
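Because the copy above is asynchronous, whatever later reads cross.mtp_embd has to wait for the backend scheduler first; a fragment-style sketch reusing the surrounding variables (the explicit synchronization shown here is an assumption, not part of this diff):

// sketch: total bytes staged for this ubatch, matching the call above
const size_t n_bytes = (size_t) n_outputs * (size_t) n_embd_mtp * sizeof(float);
ggml_backend_tensor_get_async(backend_embd, t_embd, cross.mtp_embd.data(), 0, n_bytes);

// before the MTP graph (or any host-side code) reads cross.mtp_embd,
// pending async copies must be flushed, e.g.:
ggml_backend_sched_synchronize(sched.get());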
@@ -3523,7 +3553,7 @@ int32_t llama_encode(
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
const int ret = ctx->decode(batch, false);
if (ret != 0 && ret != 1) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
@@ -3531,6 +3561,17 @@ int32_t llama_decode(
return ret;
}
int32_t llama_decode_mtp(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch, true);
if (ret != 0 && ret != 1) {
LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
}
return ret;
}
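A hedged sketch of handling the new wrapper's return values on the caller side; the meaning of the codes mirrors llama_decode() (0 success, 1 no KV-cache slot found, negative error), and `ctx` / `batch` are assumed caller state.

const int ret = llama_decode_mtp(ctx, batch);
if (ret == 1) {
    // no KV-cache slot found: retry with a smaller batch or after freeing context space
} else if (ret < 0) {
    // hard error: treat the context state as suspect and abort this generation
}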
//
// perf
//

View File

@@ -131,7 +131,7 @@ struct llama_context {
ggml_status & ret);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp, bool is_mtp);
//
// state save/load

View File

@@ -1431,7 +1431,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
}
ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
auto inp = std::make_unique<llm_graph_input_cross_mtp>(hparams.n_pos_per_embd());
auto inp = std::make_unique<llm_graph_input_cross_mtp>(cross);
auto & cur = inp->cross_mtp;
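The input class now receives the whole cross state instead of a position count, presumably so the graph can derive the MTP input tensor shape from it; an illustrative stand-alone sketch of that idea (simplified types, not the actual llm_graph_input_cross_mtp definition):

#include <cstdint>
#include <vector>

// simplified stand-in for the relevant fields of llama_cross
struct cross_state {
    int64_t n_embd  = 0;
    int64_t n_token = 0;
    std::vector<float> mtp_embd;    // n_embd * n_token floats
};

// holding a pointer to the cross state lets the graph input derive the
// tensor shape [n_embd, n_token] at build time
struct cross_mtp_input {
    explicit cross_mtp_input(const cross_state * cross) : cross(cross) {}

    int64_t ne0() const { return cross->n_embd;  }   // embedding dimension
    int64_t ne1() const { return cross->n_token; }   // number of staged outputs

    const cross_state * cross = nullptr;
};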

View File

@@ -30,6 +30,7 @@ enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
LLM_GRAPH_TYPE_DECODER_MTP,
};
enum llm_ffn_op_type {
@@ -444,7 +445,6 @@ class llm_graph_result;
struct llm_graph_params {
llm_arch arch = LLM_ARCH_UNKNOWN;
bool is_mtp = false;
llama_hparams hparams;
llama_cparams cparams;

View File

@@ -76,6 +76,10 @@ uint32_t llama_hparams::get_n_embd_out() const {
return n_embd_out > 0 ? n_embd_out : n_embd;
}
uint32_t llama_hparams::get_n_embd_mtp() const {
return n_embd;
}
uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);

View File

@@ -240,6 +240,9 @@ struct llama_hparams {
// dimension of output embeddings
uint32_t get_n_embd_out() const;
// dimension of cross embeddings between main LLM and MTP
uint32_t get_n_embd_mtp() const;
// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;

View File

@@ -7871,7 +7871,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GLM4_MOE:
{
if (params.is_mtp) {
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
llm = std::make_unique<llm_build_glm4_moe<true>>(*this, params);
} else {
llm = std::make_unique<llm_build_glm4_moe<false>>(*this, params);
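The dispatch above picks one of two template instantiations of the same builder; an illustrative pattern (not the actual llm_build_glm4_moe code) of how a compile-time bool switches the MTP head in and out:

// illustrative only: compile-time selection between the regular decoder stack
// and the MTP (nextn) head, in the spirit of llm_build_glm4_moe<true/false>
template <bool is_mtp>
struct graph_builder {
    void build() {
        if constexpr (is_mtp) {
            build_mtp_head();       // consume cross embeddings, emit draft logits
        } else {
            build_main_layers();    // regular decoder layers
        }
    }
    void build_mtp_head()    { /* build only the nextn/MTP layers */ }
    void build_main_layers() { /* build the full decoder stack    */ }
};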