add llama_context graph_type

Xuan Son Nguyen 2026-02-06 23:29:23 +01:00
parent a2860dc85e
commit 64f05859db
3 changed files with 30 additions and 8 deletions

include/llama.h

@@ -194,6 +194,13 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW   = 2, // split layers and KV across GPUs, use tensor parallelism if supported
     };
 
+    enum llama_graph_type {
+        LLAMA_GRAPH_TYPE_DEFAULT,
+        LLAMA_GRAPH_TYPE_ENCODER,
+        LLAMA_GRAPH_TYPE_DECODER,
+        LLAMA_GRAPH_TYPE_DECODER_MTP,
+    };
+
     // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
@@ -370,13 +377,14 @@ extern "C" {
         bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        bool is_mtp;      // create context for Multi-Token Prediction (MTP)
 
         // [EXPERIMENTAL]
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
         // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
+
+        llama_graph_type graph_type; // type of the computation graph to be used
     };
 
     // model quantization parameters

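For context, a minimal sketch of how a caller might use the new field from the public API above; this is not part of the commit. It assumes the existing llama_context_default_params() and llama_init_from_model() entry points, and the make_mtp_context() helper name is made up for illustration. Model loading and error handling are elided.

    // Hypothetical helper: build a secondary context that runs the MTP decoder graph.
    // graph_type replaces the old is_mtp flag removed by this commit.
    static struct llama_context * make_mtp_context(struct llama_model * model) {
        struct llama_context_params cparams = llama_context_default_params();
        cparams.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP;
        return llama_init_from_model(model, cparams); // NULL on failure
    }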
src/llama-context.cpp

@@ -27,8 +27,6 @@ llama_context::llama_context(
     // may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
-    is_mtp = params.is_mtp;
-
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
@@ -189,6 +187,23 @@ llama_context::llama_context(
         }
     }
 
+    switch (params.graph_type) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+            gtype = LLM_GRAPH_TYPE_DEFAULT;
+            break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            gtype = LLM_GRAPH_TYPE_ENCODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER:
+            gtype = LLM_GRAPH_TYPE_DECODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER_MTP:
+            gtype = LLM_GRAPH_TYPE_DECODER_MTP;
+            break;
+        default:
+            throw std::runtime_error("invalid graph type");
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx     = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
@@ -814,7 +829,7 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 }
 
 int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
-    if (!ctx_mtp.is_mtp) {
+    if (ctx_mtp.gtype != LLM_GRAPH_TYPE_DECODER_MTP) {
         LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
         return -1;
     }
@@ -1488,7 +1503,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
-    if (is_mtp) {
+    if (gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
         if (model.hparams.nextn_predict_layers == 0) {
             LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
             return -1;
@@ -1665,7 +1680,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         ggml_status status;
-        llm_graph_type gtype = is_mtp ? LLM_GRAPH_TYPE_DECODER_MTP : LLM_GRAPH_TYPE_DECODER;
 
         const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);
         if (!res) {
@@ -3029,9 +3043,9 @@ llama_context_params llama_context_default_params() {
         /*.op_offload =*/ true,
         /*.swa_full   =*/ true,
         /*.kv_unified =*/ false,
-        /*.is_mtp     =*/ false,
         /*.sampler    =*/ nullptr,
         /*.n_sampler  =*/ 0,
+        /*.graph_type =*/ LLAMA_GRAPH_TYPE_DEFAULT,
     };
 
     return result;

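As a rough illustration of what the plumbing above enables, here is a sketch at the level of the internal C++ llama_context class (not the public C API; no public wrapper for cpy_mtp_state is added in this commit). It assumes the llama_context(model, params) constructor changed in this file and an already loaded llama_model named model; error handling is elided.

    // One primary decoder context and one MTP context over the same model.
    llama_context_params params_main = llama_context_default_params();
    params_main.graph_type = LLAMA_GRAPH_TYPE_DECODER;

    llama_context_params params_mtp = params_main;
    params_mtp.graph_type  = LLAMA_GRAPH_TYPE_DECODER_MTP;

    llama_context ctx_main(model, params_main);
    llama_context ctx_mtp (model, params_mtp);

    // Hand state over to the MTP context; returns -1 unless ctx_mtp was
    // created with LLAMA_GRAPH_TYPE_DECODER_MTP (see the check above).
    const int32_t ret = ctx_main.cpy_mtp_state(ctx_mtp);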
src/llama-context.h

@@ -351,7 +351,7 @@ private:
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
-    bool is_mtp = false;
+    llm_graph_type gtype;
 
     bool has_evaluated_once = false;