lora: make sure the model keeps track of associated adapters
This commit is contained in:
parent cd78e57c3a
commit f5e8bfddc3
src/llama-adapter.cpp
@@ -413,8 +413,8 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l
         }
     }
 
-    // update number of nodes used
-    model.n_lora_nodes += adapter.get_n_nodes();
+    // register adapter with model
+    model.loras.insert(&adapter);
 
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
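Note: with this hunk, successfully initializing an adapter also registers it in the owning model's `loras` set. A minimal usage sketch against the public llama.h API (file paths are placeholders):

    // sketch: adapter init now records the adapter in the model's `loras` set
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);

    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
    // the model now tracks `adapter`; see the destructor hunk below for cleanup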
@@ -474,9 +474,10 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
 }
 
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
-    // update number of nodes used
-    GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes());
-    adapter->model.n_lora_nodes -= adapter->get_n_nodes();
+    // remove adapter from associated model
+    auto & model = adapter->model;
+    GGML_ASSERT(model.loras.find(adapter) != model.loras.end());
+    model.loras.erase(adapter);
 
     delete adapter;
 }
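Note: freeing an adapter now unlinks it from its model before deleting it, and the `GGML_ASSERT` guards against freeing an adapter that is no longer registered (e.g. a double free). Continuing the sketch above:

    // sketch: explicit free deregisters the adapter first
    llama_adapter_lora_free(adapter); // erases the pointer from model.loras, then deletes
    llama_model_free(model);          // destructor finds an empty set; nothing left to free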
src/llama-context.cpp
@@ -1443,7 +1443,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
         return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
     }
     uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
-    res += model.n_lora_nodes;
+    for (const auto & lora : model.loras) {
+        res += lora->get_n_nodes();
+    }
     return res;
 }
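Note: the fixed `n_lora_nodes` counter is replaced by summing each registered adapter's node count at query time, so the budget always reflects the adapters currently alive. A self-contained sketch of the same computation with stand-in types:

    #include <algorithm>
    #include <cstdint>
    #include <set>

    // stand-in for llama_adapter_lora, which exposes get_n_nodes()
    struct lora_stub {
        uint32_t n_nodes;
        uint32_t get_n_nodes() const { return n_nodes; }
    };

    // mirrors the non-early-return path of graph_max_nodes() above
    uint32_t graph_max_nodes_sketch(uint32_t n_tensors, const std::set<lora_stub *> & loras) {
        uint32_t res = std::max<uint32_t>(1024u, 8u * n_tensors);
        for (const auto * lora : loras) {
            res += lora->get_n_nodes(); // each adapter contributes its own graph nodes
        }
        return res;
    }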
src/llama-model.cpp
@@ -467,7 +467,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi
     pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern;
 }
 
-llama_model::~llama_model() = default;
+llama_model::~llama_model() {
+    for (auto * lora : loras) {
+        llama_adapter_lora_free(lora);
+    }
+}
 
 void llama_model::load_stats(llama_model_loader & ml) {
     pimpl->n_elements = ml.n_elements;
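Caution (editor's note): as shown, `llama_adapter_lora_free()` erases the adapter from the very `loras` set the range-for is iterating, which invalidates the loop iterator. A safe variant, assuming the same types, iterates over a snapshot:

    llama_model::~llama_model() {
        // copy first: llama_adapter_lora_free() mutates `loras` via erase()
        const std::set<llama_adapter_lora *> loras_copy = loras;
        for (auto * lora : loras_copy) {
            llama_adapter_lora_free(lora); // unlinks from `loras`, then deletes
        }
    }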
src/llama-model.h
@@ -12,6 +12,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <set>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -475,8 +476,8 @@ struct llama_model {
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
-    // for keeping track of extra nodes used by lora adapters
-    uint32_t n_lora_nodes = 0;
+    // for keeping track of associated LoRA adapters
+    std::set<llama_adapter_lora *> loras;
 
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
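Note: a `std::set` of non-owning raw pointers gives the model O(log n) membership tests and erasure, which is exactly what the `GGML_ASSERT` in `llama_adapter_lora_free()` relies on to catch a double free. A toy demonstration with `int *` stand-ins:

    #include <cassert>
    #include <set>

    int main() {
        std::set<int *> loras;
        int a = 0;

        loras.insert(&a);                      // register
        assert(loras.find(&a) != loras.end()); // membership check, as in the free path
        loras.erase(&a);                       // deregister exactly once
        assert(loras.find(&a) == loras.end()); // a second free would trip the assert
        return 0;
    }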