diff --git a/common/common.cpp b/common/common.cpp
index c31619ac36..40dea7fd4b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -943,14 +943,26 @@ std::vector fs_list_files(const std::string & path) {
 // Model utils
 //
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
+llama_model * common_load_model_from_params(common_params & params) {
     auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
         LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
+        return nullptr;
+    }
+
+    return model;
+}
+
+struct common_init_result common_init_context_from_model(
+        llama_model * model,
+        common_params & params) {
+    common_init_result iparams;
+
+    if (model == NULL) {
+        LOG_ERR("%s: model is NULL\n", __func__);
         return iparams;
     }
 
@@ -1125,6 +1137,14 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
+struct common_init_result common_init_from_params(common_params & params) {
+    llama_model * model = common_load_model_from_params(params);
+    if (model == NULL) {
+        return common_init_result();
+    }
+    return common_init_context_from_model(model, params);
+}
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
diff --git a/common/common.h b/common/common.h
index be34bcb78c..5d289a116b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -640,6 +640,14 @@ struct common_init_result {
 
 struct common_init_result common_init_from_params(common_params & params);
 
+// Load model only (allows creating backend samplers before context initialization)
+llama_model * common_load_model_from_params(common_params & params);
+
+// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
+struct common_init_result common_init_context_from_model(
+        llama_model * model,
+        common_params & params);
+
 struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index f1d0fd4b60..06185e47eb 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -137,18 +137,29 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
 
-    common_init_result llama_init = common_init_from_params(params);
-
-    model = llama_init.model.get();
-    ctx   = llama_init.context.get();
+    model = common_load_model_from_params(params);
 
     if (model == NULL) {
        LOG_ERR("%s: error: unable to load model\n", __func__);
        return 1;
     }
 
-    // Configure backend sampler chain
-    llama_set_backend_sampler(ctx, 0, common_sampler_backend_init(model, sparams));
+    // Configure the backend sampler, if any
+    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
+    llama_sampler_seq_config sampler_config = { 0, backend_sampler }; // must outlive context creation
+    if (backend_sampler) {
+        params.backend_samplers   = &sampler_config;
+        params.n_backend_samplers = 1;
+    }
+
+    common_init_result llama_init = common_init_context_from_model(model, params);
+    ctx   = llama_init.context.get();
+    model = llama_init.model.get(); // Update pointer (now managed by llama_init)
+
+    if (ctx == NULL) {
+        LOG_ERR("%s: error: unable to create context\n", __func__);
+        return 1;
+    }
 
     auto * mem = llama_get_memory(ctx);
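
For reference, a usage sketch of the split API introduced above (not part of the patch). It assumes a parsed common_params `params` and sampling params `sparams` already in scope, as in tools/main/main.cpp, and mirrors the call order the patch establishes: load the model, build a backend sampler from it, attach the sampler through params, then create the context.

    // 1. load only the model so a backend sampler can be built from it
    llama_model * model = common_load_model_from_params(params);
    if (model == NULL) {
        return 1;
    }

    // 2. build the backend sampler (may be NULL when none is configured)
    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);

    // 3. attach the sampler for sequence 0 via the params;
    //    sampler_config must stay alive until the context has been created
    llama_sampler_seq_config sampler_config = { 0, backend_sampler };
    if (backend_sampler) {
        params.backend_samplers   = &sampler_config;
        params.n_backend_samplers = 1;
    }

    // 4. create the context from the already-loaded model; ownership moves into llama_init
    common_init_result llama_init = common_init_context_from_model(model, params);
    llama_context * ctx = llama_init.context.get();
    model = llama_init.model.get();
    if (ctx == NULL) {
        return 1;
    }

Note the lifetime requirement in step 3: params.backend_samplers stores a pointer, so the llama_sampler_seq_config is declared in the enclosing scope rather than inside the if block, keeping it valid at the point where common_init_context_from_model reads it.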