common, tools : refactor model loading to support backend samplers

This commit refactors the model loading process in common/common.cpp
to enable backend samplers to be configured prior to llama_context
creation.

The motivation for this change is that being able to set/reset the
backend samplers only after the llama_context has been created triggers
a resize in llama_context::output_reserve, which we want to avoid.
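
In practice, the refactor splits the old single call into two phases. The
following is a minimal sketch of the intended call sequence, assembled from
the diffs below (error handling elided; sparams stands for the run's
sampling parameters, as in the tools change):

    llama_model * model = common_load_model_from_params(params);

    // build a backend sampler for sequence 0 before any llama_context exists
    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
    llama_sampler_seq_config sampler_config = { 0, backend_sampler };
    params.backend_samplers   = &sampler_config;
    params.n_backend_samplers = 1;

    // the context is created with the samplers already configured, so
    // llama_context::output_reserve can size its buffers once, up front
    common_init_result llama_init = common_init_context_from_model(model, params);
    llama_context * ctx = llama_init.context.get();
    model = llama_init.model.get(); // ownership now held by llama_init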
Author: Daniel Bevenius
Date:   2025-11-21 14:26:52 +01:00
Parent: 61ffe41dc1
Commit: 9b2439347f
3 changed files with 47 additions and 8 deletions

common/common.cpp

@@ -943,14 +943,26 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
 // Model utils
 //
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
+llama_model * common_load_model_from_params(common_params & params) {
     auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
 
     if (model == NULL) {
         LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
             __func__, params.model.path.c_str());
-        return iparams;
+        return nullptr;
+    }
+
+    return model;
+}
+
+struct common_init_result common_init_context_from_model(
+        llama_model * model,
+        common_params & params) {
+    common_init_result iparams;
+
+    if (model == NULL) {
+        LOG_ERR("%s: model is NULL\n", __func__);
+        return iparams;
     }
@@ -1125,6 +1137,14 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
+struct common_init_result common_init_from_params(common_params & params) {
+    llama_model * model = common_load_model_from_params(params);
+    if (model == NULL) {
+        return common_init_result();
+    }
+    return common_init_context_from_model(model, params);
+}
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
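
Note that common_init_from_params is kept as a thin wrapper over the two new
functions, so existing callers that do not need to pre-configure backend
samplers are unaffected by the split.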

common/common.h

@@ -640,6 +640,14 @@ struct common_init_result {
 struct common_init_result common_init_from_params(common_params & params);
 
+// Load model only (allows creating backend samplers before context initialization)
+llama_model * common_load_model_from_params(common_params & params);
+
+// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
+struct common_init_result common_init_context_from_model(
+        llama_model * model,
+        common_params & params);
+
 struct llama_model_params     common_model_params_to_llama  (      common_params & params);
 struct llama_context_params   common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);

tools/main/main.cpp

@@ -137,18 +137,29 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
-    common_init_result llama_init = common_init_from_params(params);
-
-    model = llama_init.model.get();
-    ctx   = llama_init.context.get();
+    model = common_load_model_from_params(params);
 
     if (model == NULL) {
         LOG_ERR("%s: error: unable to load model\n", __func__);
         return 1;
     }
 
-    // Configure backend sampler chain
-    llama_set_backend_sampler(ctx, 0, common_sampler_backend_init(model, sparams));
+    // Configure backend sampler if configured
+    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
+    if (backend_sampler) {
+        llama_sampler_seq_config sampler_config = { 0, backend_sampler };
+        params.backend_samplers   = &sampler_config;
+        params.n_backend_samplers = 1;
+    }
+
+    common_init_result llama_init = common_init_context_from_model(model, params);
+    ctx   = llama_init.context.get();
+    model = llama_init.model.get(); // Update pointer (now managed by llama_init)
+
+    if (ctx == NULL) {
+        LOG_ERR("%s: error: unable to create context\n", __func__);
+        return 1;
+    }
 
     auto * mem = llama_get_memory(ctx);
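
One design detail worth noting in the tools change: sampler_config is a stack
variable and params.backend_samplers stores a pointer to it, which relies on
the configuration being consumed while common_init_context_from_model creates
the context (the point of configuring before creation). The model pointer is
also re-read from llama_init afterwards, since ownership of both the model and
the context now lives in the returned common_init_result.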