diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp
index bdc24c2d6b..5ff22b18ab 100644
--- a/src/llama-adapter.cpp
+++ b/src/llama-adapter.cpp
@@ -413,8 +413,8 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l
         }
     }
 
-    // update number of nodes used
-    model.n_lora_nodes += adapter.get_n_nodes();
+    // register adapter with model
+    model.loras.insert(&adapter);
 
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
@@ -474,9 +474,10 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
 }
 
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
-    // update number of nodes used
-    GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes());
-    adapter->model.n_lora_nodes -= adapter->get_n_nodes();
+    // remove adapter from associated model
+    auto & model = adapter->model;
+    GGML_ASSERT(model.loras.find(adapter) != model.loras.end());
+    model.loras.erase(adapter);
 
     delete adapter;
 }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 34dfcd4724..bcea01a997 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1443,7 +1443,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
         return std::max(n_tokens * 40, 32u * model.n_tensors());
     }
     uint32_t res = std::max(1024u, 8u*model.n_tensors());
-    res += model.n_lora_nodes;
+    for (const auto & lora : model.loras) {
+        res += lora->get_n_nodes();
+    }
 
     return res;
 }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5e664c8c57..1b220af83e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -467,7 +467,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
     pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
-llama_model::~llama_model() = default;
+llama_model::~llama_model() {
+    while (!loras.empty()) {
+        llama_adapter_lora_free(*loras.begin());
+    }
+}
 
 void llama_model::load_stats(llama_model_loader & ml) {
     pimpl->n_elements = ml.n_elements;
diff --git a/src/llama-model.h b/src/llama-model.h
index f4f44a92b6..838d9cd6e5 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -12,6 +12,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <set>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -475,8 +476,8 @@ struct llama_model {
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
-    // for keeping track of extra nodes used by lora adapters
-    uint32_t n_lora_nodes = 0;
+    // for keeping track of associated LoRA adapters
+    std::set<llama_adapter_lora *> loras;
 
    int64_t t_load_us = 0;
    int64_t t_start_us = 0;
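
Below is a minimal usage sketch of the adapter lifetime implied by this change, assuming the public llama.h entry points (llama_model_load_from_file, llama_adapter_lora_init, llama_adapter_lora_free, llama_model_free); the file paths are placeholders and error handling is omitted. An adapter registers itself with its model when loaded, an explicit llama_adapter_lora_free() unregisters it, and any adapters still registered when the model is freed are cleaned up by ~llama_model().

#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

    // loading an adapter inserts it into the model's set of adapters
    llama_adapter_lora * lora_a = llama_adapter_lora_init(model, "adapter_a.gguf");
    llama_adapter_lora * lora_b = llama_adapter_lora_init(model, "adapter_b.gguf");

    // explicit free erases the adapter from the model's set before deleting it
    llama_adapter_lora_free(lora_a);

    // lora_b is still registered, so ~llama_model() frees it here;
    // it must not be freed again by the caller afterwards
    (void) lora_b;
    llama_model_free(model);

    llama_backend_free();
    return 0;
}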