From 64f05859dbfcb8dfff507eb58cc8c00924271dae Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 6 Feb 2026 23:29:23 +0100
Subject: [PATCH] add llama_context graph_type

---
 include/llama.h       | 10 +++++++++-
 src/llama-context.cpp | 26 ++++++++++++++++++++------
 src/llama-context.h   |  2 +-
 3 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 0843f4119a..fa1872add1 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -194,6 +194,13 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW   = 2, // split layers and KV across GPUs, use tensor parallelism if supported
     };
 
+    enum llama_graph_type {
+        LLAMA_GRAPH_TYPE_DEFAULT,
+        LLAMA_GRAPH_TYPE_ENCODER,
+        LLAMA_GRAPH_TYPE_DECODER,
+        LLAMA_GRAPH_TYPE_DECODER_MTP,
+    };
+
     // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
@@ -370,13 +377,14 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        bool is_mtp;     // create context for Multi-Token Prediction (MTP) // [EXPERIMENTAL]
 
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
         // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
        struct llama_sampler_seq_config * samplers;
        size_t n_samplers;
+
+        llama_graph_type graph_type; // type of the computation graph to be used
     };
 
     // model quantization parameters
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c513e48dca..67947894f0 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -27,8 +27,6 @@ llama_context::llama_context(
     // may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
-    is_mtp = params.is_mtp;
-
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
 
@@ -189,6 +187,23 @@ llama_context::llama_context(
         }
     }
 
+    switch (params.graph_type) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+            gtype = LLM_GRAPH_TYPE_DEFAULT;
+            break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            gtype = LLM_GRAPH_TYPE_ENCODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER:
+            gtype = LLM_GRAPH_TYPE_DECODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER_MTP:
+            gtype = LLM_GRAPH_TYPE_DECODER_MTP;
+            break;
+        default:
+            throw std::runtime_error("invalid graph type");
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max     = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx         = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_seq     = %u\n", __func__, cparams.n_ctx_seq);
@@ -814,7 +829,7 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 }
 
 int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
-    if (!ctx_mtp.is_mtp) {
+    if (ctx_mtp.gtype != LLM_GRAPH_TYPE_DECODER_MTP) {
         LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
         return -1;
     }
@@ -1488,7 +1503,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map
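
Usage sketch (reviewer note, not part of the patch): the diff replaces the is_mtp flag with the new graph_type field on llama_context_params, so a caller selects the MTP graph by setting that field before context creation. The sketch below assumes the rest of the public llama.h API on this branch matches current master (llama_model_load_from_file, llama_init_from_model, etc.); the model path and call order are purely illustrative.

#include "llama.h"

int main(void) {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == NULL) {
        return 1;
    }

    // main decoding context: keep the default graph type
    llama_context_params cparams = llama_context_default_params();
    cparams.graph_type = LLAMA_GRAPH_TYPE_DEFAULT;
    llama_context * ctx = llama_init_from_model(model, cparams);

    // auxiliary context for Multi-Token Prediction, replacing the removed is_mtp flag
    llama_context_params cparams_mtp = llama_context_default_params();
    cparams_mtp.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP;
    llama_context * ctx_mtp = llama_init_from_model(model, cparams_mtp);

    // ... decode on ctx, draft with ctx_mtp ...

    llama_free(ctx_mtp);
    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}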