diff --git a/common/common.cpp b/common/common.cpp
index 8f2dfd8215..8e7cbfe1da 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -950,26 +950,14 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
 // Model utils
 //
 
-llama_model * common_load_model_from_params(common_params & params) {
+struct common_init_result common_init_from_params(common_params & params) {
+    common_init_result iparams;
     auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
 
     if (model == NULL) {
         LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
-        return nullptr;
-    }
-
-    return model;
-}
-
-struct common_init_result common_init_context_from_model(
-        llama_model * model,
-        common_params & params) {
-    common_init_result iparams;
-
-    if (model == NULL) {
-        LOG_ERR("%s: model is NULL\n", __func__);
         return iparams;
     }
@@ -977,6 +965,16 @@ struct common_init_result common_init_context_from_model(
 
     auto cparams = common_context_params_to_llama(params);
 
+    // backend sampling initialization
+    if (params.sampling.backend_sampling) {
+        iparams.samplers_seq_config.resize(cparams.n_seq_max);
+        for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+            iparams.samplers_seq_config[i] = { i, common_sampler_backend_init(model, params.sampling) };
+        }
+        cparams.samplers = iparams.samplers_seq_config.data();
+        cparams.n_samplers = cparams.n_seq_max;
+    }
+
     llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
@@ -1142,14 +1140,6 @@ struct common_init_result common_init_context_from_model(
     return iparams;
 }
 
-struct common_init_result common_init_from_params(common_params & params) {
-    llama_model * model = common_load_model_from_params(params);
-    if (model == NULL) {
-        return common_init_result();
-    }
-    return common_init_context_from_model(model, params);
-}
-
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -1245,9 +1235,6 @@ struct llama_context_params common_context_params_to_llama(const common_params & params) {
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;
 
-    cparams.samplers = params.backend_samplers;
-    cparams.n_samplers = params.n_backend_samplers;
-
     return cparams;
 }
 
diff --git a/common/common.h b/common/common.h
index 01e6dfe59b..8c04715709 100644
--- a/common/common.h
+++ b/common/common.h
@@ -523,9 +523,6 @@ struct common_params {
     bool has_speculative() const {
         return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
     }
-
-    llama_sampler_seq_config * backend_samplers = NULL;
-    size_t n_backend_samplers = 0;
 };
 
 // call once at the start of a program if it uses libcommon
@@ -643,18 +640,13 @@ struct common_init_result {
     llama_context_ptr context;
 
     std::vector<llama_adapter_lora_ptr> lora;
+
+    std::vector<llama_sampler_ptr> samplers;
+    std::vector<llama_sampler_seq_config> samplers_seq_config;
 };
 
 struct common_init_result common_init_from_params(common_params & params);
 
-// Load model only (allows creating backend samplers before context initialization)
-llama_model * common_load_model_from_params(common_params & params);
-
-// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
-struct common_init_result common_init_context_from_model(
-        llama_model * model,
-        common_params & params);
-
 struct llama_model_params   common_model_params_to_llama ( common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
diff --git a/tools/main/main.cpp b/tools/main/main.cpp
index cae778e551..263387f417 100644
--- a/tools/main/main.cpp
+++ b/tools/main/main.cpp
@@ -138,22 +138,7 @@ int main(int argc, char ** argv) {
 
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
-    model = common_load_model_from_params(params);
-    if (model == NULL) {
-        LOG_ERR("%s: error: unable to load model\n", __func__);
-        return 1;
-    }
-
-    // Configure backend sampler if configured
-    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
-    llama_sampler_seq_config sampler_config = { 0, backend_sampler };
-
-    if (backend_sampler) {
-        params.backend_samplers = &sampler_config;
-        params.n_backend_samplers = 1;
-    }
-
-    common_init_result llama_init = common_init_context_from_model(model, params);
+    common_init_result llama_init = common_init_from_params(params);
 
     ctx = llama_init.context.get();
     model = llama_init.model.get(); // Update pointer (now managed by llama_init)
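Usage note: with this change, common_init_from_params creates the per-sequence backend samplers internally, so callers only set the flag instead of wiring llama_sampler_seq_config by hand. A minimal caller sketch, assuming params has been populated first (e.g. via common_params_parse) and using only names that appear in the hunks above:

    common_params params;
    params.sampling.backend_sampling = true; // opt in to backend sampling

    common_init_result llama_init = common_init_from_params(params);

    llama_model   * model = llama_init.model.get();
    llama_context * ctx   = llama_init.context.get();
    if (model == NULL || ctx == NULL) {
        return 1; // model load or context creation failed
    }

    // the llama_sampler_seq_config entries handed to the context via
    // cparams.samplers live in llama_init.samplers_seq_config, so it is
    // safest to keep llama_init alive for as long as ctx is used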