From cd78e57c3aeae7b56c5843f94e0e0b83a3d8ca81 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 30 Dec 2025 15:53:12 +0100 Subject: [PATCH] lora: count lora nodes in graph_max_nodes (#18469) * lora: count lora nodes in graph_max_nodes * 3 nodes per weight * 4 nodes * keep track n_lora_nodes from llama_model * fix assert * rm redundant header * common: load adapters before context creation * use 6 nodes --- common/common.cpp | 37 +++++++++++++++++++------------------ include/llama.h | 2 ++ src/llama-adapter.cpp | 15 ++++++++++++--- src/llama-adapter.h | 8 +++++++- src/llama-context.cpp | 4 +++- src/llama-model.h | 3 +++ 6 files changed, 46 insertions(+), 23 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 58fef59546..79c4756125 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1109,6 +1109,25 @@ common_init_result::common_init_result(common_params & params) : const llama_vocab * vocab = llama_model_get_vocab(model); + // load and optionally apply lora adapters (must be loaded before context creation) + for (auto & la : params.lora_adapters) { + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str()); + pimpl->model.reset(model); + return; + } + + char buf[1024]; + la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; + pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + } + // updates params.sampling // TODO: fix naming common_init_sampler_from_model(model, params.sampling); @@ -1245,24 +1264,6 @@ common_init_result_ptr common_init_from_params(common_params & params) { } } - // load and optionally apply lora adapters - for (auto & la : params.lora_adapters) { - llama_adapter_lora_ptr lora; - lora.reset(llama_adapter_lora_init(model, la.path.c_str())); - if (lora == nullptr) { - LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - return res; - } - - char buf[1024]; - la.ptr = lora.get(); - llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); - la.task_name = buf; - llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); - la.prompt_prefix = buf; - res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters - } - if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } diff --git a/include/llama.h b/include/llama.h index 4f0124fdc8..8b3c8a7b10 100644 --- a/include/llama.h +++ b/include/llama.h @@ -607,6 +607,8 @@ extern "C" { // // Load a LoRA adapter from file + // The adapter is valid as long as the associated model is not freed + // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index d8eef75a7a..bdc24c2d6b 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -146,9 +146,11 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { return nullptr; } -static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { +static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); + llama_model & model = adapter.model; + ggml_context * ctx_init; gguf_init_params meta_gguf_params = { /* .no_alloc = */ true, @@ -411,14 +413,17 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } } + // update number of nodes used + model.n_lora_nodes += adapter.get_n_nodes(); + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(); + llama_adapter_lora * adapter = new llama_adapter_lora(*model); try { - llama_adapter_lora_init_impl(*model, path_lora, *adapter); + llama_adapter_lora_init_impl(path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -469,6 +474,10 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, } void llama_adapter_lora_free(llama_adapter_lora * adapter) { + // update number of nodes used + GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes()); + adapter->model.n_lora_nodes -= adapter->get_n_nodes(); + delete adapter; } diff --git a/src/llama-adapter.h b/src/llama-adapter.h index 4f65247c0f..42d64a6e0b 100644 --- a/src/llama-adapter.h +++ b/src/llama-adapter.h @@ -59,6 +59,8 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { + llama_model & model; + // map tensor name to lora_a_b std::unordered_map ab_map; @@ -73,10 +75,14 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora() = default; + llama_adapter_lora(llama_model & model) : model(model) {} ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); + + uint32_t get_n_nodes() const { + return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat + } }; using llama_adapter_loras = std::unordered_map; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 1c530fdc91..34dfcd4724 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1442,7 +1442,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { if (model.arch == LLM_ARCH_QWEN3NEXT) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } - return std::max(1024u, 8u*model.n_tensors()); + uint32_t res = std::max(1024u, 8u*model.n_tensors()); + res += model.n_lora_nodes; + return res; } llm_graph_result * llama_context::get_gf_res_reserve() const { diff --git a/src/llama-model.h b/src/llama-model.h index dbe5edc153..f4f44a92b6 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,6 +475,9 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; + // for keeping track of extra nodes used by lora adapters + uint32_t n_lora_nodes = 0; + int64_t t_load_us = 0; int64_t t_start_us = 0;