Make graph_max_nodes vary by ubatch size (#17794)

* Make graph_max_nodes vary by ubatch size for models where chunking might explode the graph

* Update src/llama-context.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Add missing const

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Piotr Wilkin (ilintar) 2025-12-08 14:32:41 +01:00 committed by GitHub
parent 636fc17a37
commit e4e9c4329c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 7 deletions

View File

@ -248,7 +248,10 @@ llama_context::llama_context(
LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
const size_t max_nodes = this->graph_max_nodes();
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
const size_t max_nodes = this->graph_max_nodes(n_tokens);
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
@ -300,9 +303,6 @@ llama_context::llama_context(
cross.v_embd.clear();
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// avoid reserving graphs with zero outputs - assume one output per sequence
n_outputs = n_seqs;
@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
// graph
//
uint32_t llama_context::graph_max_nodes() const {
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}

View File

@ -197,7 +197,7 @@ private:
//
public:
uint32_t graph_max_nodes() const;
uint32_t graph_max_nodes(uint32_t n_tokens) const;
// can reuse the llm_graph_result instance of the context (for example to update a memory module)
llm_graph_result * get_gf_res_reserve() const;