lora: make sure model keeps track of associated adapters

Xuan Son Nguyen 2025-12-30 15:57:21 +01:00
parent cd78e57c3a
commit f5e8bfddc3
4 changed files with 17 additions and 9 deletions


@@ -413,8 +413,8 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l
         }
     }
 
-    // update number of nodes used
-    model.n_lora_nodes += adapter.get_n_nodes();
+    // register adapter with model
+    model.loras.insert(&adapter);
 
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
@@ -474,9 +474,10 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
 }
 
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
-    // update number of nodes used
-    GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes());
-    adapter->model.n_lora_nodes -= adapter->get_n_nodes();
+    // remove adapter from associated model
+    auto & model = adapter->model;
+    GGML_ASSERT(model.loras.find(adapter) != model.loras.end());
+    model.loras.erase(adapter);
 
     delete adapter;
 }
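
Editor's note: the new GGML_ASSERT turns a double free, or freeing an adapter against a model it was never registered with, into a hard failure instead of silently corrupting the registry. A minimal stand-alone sketch of that guard pattern; adapter_stub, registry, and free_adapter are illustrative stand-ins, not llama.cpp names:

#include <cassert>
#include <set>

struct adapter_stub {};

static std::set<adapter_stub *> registry; // stands in for model.loras

void free_adapter(adapter_stub * a) {
    // fail loudly if the adapter is unknown: double free or wrong model
    assert(registry.find(a) != registry.end());
    registry.erase(a);
    delete a;
}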


@@ -1443,7 +1443,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
         return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
 
     uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
-    res += model.n_lora_nodes;
+    for (const auto & lora : model.loras) {
+        res += lora->get_n_nodes();
+    }
     return res;
 }
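
Editor's note: the node budget is now recomputed from whichever adapters are currently registered, rather than from a counter that had to be incremented and decremented in sync with adapter lifetimes. A hedged sketch of just that arithmetic; lora_stub and graph_max_nodes_sketch are illustrative, only the formula mirrors the hunk above:

#include <algorithm>
#include <cstdint>
#include <set>

struct lora_stub {
    uint32_t n_nodes = 0;
    uint32_t get_n_nodes() const { return n_nodes; }
};

uint32_t graph_max_nodes_sketch(uint32_t n_tensors, const std::set<lora_stub *> & loras) {
    uint32_t res = std::max<uint32_t>(1024u, 8u * n_tensors);
    for (const auto & lora : loras) {
        res += lora->get_n_nodes(); // each adapter contributes its own extra graph nodes
    }
    return res;
}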


@@ -467,7 +467,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
     pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
-llama_model::~llama_model() = default;
+llama_model::~llama_model() {
+    for (auto * lora : loras) {
+        llama_adapter_lora_free(lora);
+    }
+}
 
 void llama_model::load_stats(llama_model_loader & ml) {
     pimpl->n_elements = ml.n_elements;
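
Editor's note: one caveat in this destructor. llama_adapter_lora_free (in the first file) erases the adapter from model.loras, so the range-for above erases elements from the very set it is iterating, which invalidates the loop iterator. A hedged sketch of an erase-safe variant, not what this commit ships:

llama_model::~llama_model() {
    // take the first remaining adapter each time; freeing it erases it
    // from `loras`, so the loop terminates without holding a live iterator
    while (!loras.empty()) {
        llama_adapter_lora_free(*loras.begin());
    }
}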


@@ -12,6 +12,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <set>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -475,8 +476,8 @@ struct llama_model {
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
-    // for keeping track of extra nodes used by lora adapters
-    uint32_t n_lora_nodes = 0;
+    // for keeping track of associated LoRA adapters
+    std::set<llama_adapter_lora *> loras;
 
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
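
Editor's note: taken together, the change means adapters can no longer leak past or dangle beyond their model. A hedged end-to-end sketch using the public llama.h C API; file paths are placeholders and error handling is trimmed:

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    // init registers the adapter with the model (model.loras.insert above)
    llama_adapter_lora * lora = llama_adapter_lora_init(model, "adapter.gguf");

    // either free the adapter explicitly, which unregisters it ...
    llama_adapter_lora_free(lora);

    // ... or rely on the new destructor: freeing the model also frees any
    // adapters still registered with it
    llama_model_free(model);
    return 0;
}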