common, tools : refactor model loading to support backend samplers
This commit refactors the model loading process in common/common.cpp to enable backend sampler to be configure prior to the llama_context creation. The motivation for this change is that just being able to set/reset the backend samplers after the llama_context has been created will cause a resize to occur in llama_context::output_reserve which we want to avoid.
This commit is contained in:
parent
61ffe41dc1
commit
9b2439347f
|
|
@ -943,14 +943,26 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
|
|||
// Model utils
|
||||
//
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
llama_model * common_load_model_from_params(common_params & params) {
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
struct common_init_result common_init_context_from_model(
|
||||
llama_model * model,
|
||||
common_params & params) {
|
||||
common_init_result iparams;
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: model is NULL\n", __func__);
|
||||
return iparams;
|
||||
}
|
||||
|
||||
|
|
@ -1125,6 +1137,14 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
return iparams;
|
||||
}
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
llama_model * model = common_load_model_from_params(params);
|
||||
if (model == NULL) {
|
||||
return common_init_result();
|
||||
}
|
||||
return common_init_context_from_model(model, params);
|
||||
}
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||
|
|
|
|||
|
|
@ -640,6 +640,14 @@ struct common_init_result {
|
|||
|
||||
struct common_init_result common_init_from_params(common_params & params);
|
||||
|
||||
// Load model only (allows creating backend samplers before context initialization)
|
||||
llama_model * common_load_model_from_params(common_params & params);
|
||||
|
||||
// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
|
||||
struct common_init_result common_init_context_from_model(
|
||||
llama_model * model,
|
||||
common_params & params);
|
||||
|
||||
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
||||
|
|
|
|||
|
|
@ -137,18 +137,29 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
|
||||
model = llama_init.model.get();
|
||||
ctx = llama_init.context.get();
|
||||
|
||||
model = common_load_model_from_params(params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: error: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Configure backend sampler chain
|
||||
llama_set_backend_sampler(ctx, 0, common_sampler_backend_init(model, sparams));
|
||||
// Configure backend sampler if configured
|
||||
llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
|
||||
if (backend_sampler) {
|
||||
llama_sampler_seq_config sampler_config = { 0, backend_sampler };
|
||||
params.backend_samplers = &sampler_config;
|
||||
params.n_backend_samplers = 1;
|
||||
}
|
||||
|
||||
common_init_result llama_init = common_init_context_from_model(model, params);
|
||||
ctx = llama_init.context.get();
|
||||
model = llama_init.model.get(); // Update pointer (now managed by llama_init)
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: unable to create context\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue