diff --git a/common/common.cpp b/common/common.cpp
index ec15804c91..26edcc383f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1223,7 +1223,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
             return res;
         }
 
-        int err = llama_apply_adapter_cvec(
+        int err = llama_set_adapter_cvec(
                 lctx,
                 cvec.data.data(),
                 cvec.data.size(),
@@ -1325,12 +1325,15 @@ std::string get_model_endpoint() {
 }
 
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
-    llama_clear_adapter_lora(ctx);
-    for (auto & la : lora) {
-        if (la.scale != 0.0f) {
-            llama_set_adapter_lora(ctx, la.ptr, la.scale);
-        }
+    std::vector<llama_adapter_lora *> loras;
+    std::vector<float> scales;
+
+    for (auto & la : lora) {
+        loras.push_back(la.ptr);
+        scales.push_back(la.scale);
     }
+
+    llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
 }
 
 struct llama_model_params common_model_params_to_llama(common_params & params) {
diff --git a/include/llama.h b/include/llama.h
index 305623127c..d2d7f59ebc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -656,21 +656,12 @@ extern "C" {
     // The following functions operate on a llama_context, hence the naming: llama_verb_...
 
-    // Add a loaded LoRA adapter to given context
-    // This will not modify model's weight
-    LLAMA_API int32_t llama_set_adapter_lora(
+    // Set LoRA adapters on the context. Only modifies the context if the requested adapters differ from those currently set.
+    LLAMA_API int32_t llama_set_adapters_lora(
             struct llama_context * ctx,
-            struct llama_adapter_lora * adapter,
-            float scale);
-
-    // Remove a specific LoRA adapter from given context
-    // Return -1 if the adapter is not present in the context
-    LLAMA_API int32_t llama_rm_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter);
-
-    // Remove all LoRA adapters from given context
-    LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
+            struct llama_adapter_lora ** adapters,
+            size_t n_adapters,
+            float * scales);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
@@ -678,7 +669,7 @@ extern "C" {
     //      to an n_embd x n_layers buffer starting from layer 1.
     // il_start and il_end are the layer range the vector should apply to (both inclusive)
     // See llama_control_vector_load in common to load a control vector.
-    LLAMA_API int32_t llama_apply_adapter_cvec(
+    LLAMA_API int32_t llama_set_adapter_cvec(
             struct llama_context * ctx,
                      const float * data,
                             size_t len,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6b43ca1926..ac17e1a0fe 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1057,51 +1057,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     return true;
 }
 
-void llama_context::set_adapter_lora(
-        llama_adapter_lora * adapter,
-        float scale) {
-    LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
+void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
+    LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
 
-    if (auto it = loras.find(adapter); it != loras.end()) {
-        if (it->second == scale) {
-            return;
-        }
-    }
-
-    loras[adapter] = scale;
-
-    sched_need_reserve = true;
-}
-
-bool llama_context::rm_adapter_lora(
-        llama_adapter_lora * adapter) {
-    LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
-
-    auto it = loras.find(adapter);
-    if (it != loras.end()) {
-        loras.erase(it);
-
-        sched_need_reserve = true;
-
-        return true;
-    }
-
-    return false;
-}
-
-void llama_context::clear_adapter_lora() {
-    LLAMA_LOG_DEBUG("%s: call\n", __func__);
-
-    if (loras.empty()) {
+    if (adapters_lora_are_same(adapters, n_adapters, scales)) {
         return;
     }
 
     loras.clear();
 
+    for (size_t i = 0; i < n_adapters; i++) {
+        if (scales[i] != 0.0f) {
+            loras[adapters[i]] = scales[i];
+        }
+    }
+
     sched_need_reserve = true;
 }
 
-bool llama_context::apply_adapter_cvec(
+bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
+    LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
+
+    if (n_adapters != loras.size()) {
+        return false;
+    }
+
+    for (size_t i = 0; i < n_adapters; i++) {
+        auto it = loras.find(adapters[i]);
+
+        if (it == loras.end() || it->second != scales[i]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool llama_context::set_adapter_cvec(
         const float * data,
         size_t len,
         int32_t n_embd,
@@ -3209,35 +3201,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
 
 // llama adapter API
 
-int32_t llama_set_adapter_lora(
+int32_t llama_set_adapters_lora(
         llama_context * ctx,
-        llama_adapter_lora * adapter,
-        float scale) {
-    ctx->set_adapter_lora(adapter, scale);
+        llama_adapter_lora ** adapters,
+        size_t n_adapters,
+        float * scales) {
+    if (adapters == nullptr || scales == nullptr) {
+        GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
+    }
+
+    ctx->set_adapters_lora(adapters, n_adapters, scales);
 
     return 0;
 }
 
-int32_t llama_rm_adapter_lora(
-        llama_context * ctx,
-        llama_adapter_lora * adapter) {
-    bool res = ctx->rm_adapter_lora(adapter);
-
-    return res ? 0 : -1;
-}
-
-void llama_clear_adapter_lora(llama_context * ctx) {
-    ctx->clear_adapter_lora();
-}
-
-int32_t llama_apply_adapter_cvec(
+int32_t llama_set_adapter_cvec(
         llama_context * ctx,
-                 const float * data,
-                        size_t len,
-                       int32_t n_embd,
-                       int32_t il_start,
-                       int32_t il_end) {
-    bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
+        const float * data,
+        size_t len,
+        int32_t n_embd,
+        int32_t il_start,
+        int32_t il_end) {
+    bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
 
     return res ? 0 : -1;
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index d995117574..37117ba7b6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -105,16 +105,11 @@ struct llama_context {
     void set_causal_attn(bool value);
     void set_warmup(bool value);
 
-    void set_adapter_lora(
-            llama_adapter_lora * adapter,
-            float scale);
+    void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
 
-    bool rm_adapter_lora(
-            llama_adapter_lora * adapter);
+    bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
 
-    void clear_adapter_lora();
-
-    bool apply_adapter_cvec(
+    bool set_adapter_cvec(
             const float * data,
             size_t len,
             int32_t n_embd,
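
For context, a minimal caller-side sketch of how the reworked API is meant to be used. This is not part of the patch: the helper function name and the adapter paths/scales are illustrative, and adapter loading uses the existing llama_adapter_lora_init() from the current llama.cpp API.

```cpp
// Sketch only: drive the new llama_set_adapters_lora() introduced by this patch.
#include "llama.h"

#include <vector>

static void apply_adapters(llama_context * ctx, llama_model * model) {
    // Load two adapters with the existing loader (paths are placeholders).
    llama_adapter_lora * a = llama_adapter_lora_init(model, "adapter_a.gguf");
    llama_adapter_lora * b = llama_adapter_lora_init(model, "adapter_b.gguf");

    std::vector<llama_adapter_lora *> adapters = { a, b };
    std::vector<float>                scales   = { 1.0f, 0.5f };

    // One call replaces the previous set/rm/clear sequence; the context only
    // rebuilds its LoRA state when the requested set differs from what is
    // already applied.
    llama_set_adapters_lora(ctx, adapters.data(), adapters.size(), scales.data());

    // Passing an empty set clears all adapters from the context
    // (n_adapters must be 0 when the pointer arguments are null).
    llama_set_adapters_lora(ctx, nullptr, 0, nullptr);
}
```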