common, tools : refactor model loading to support backend samplers
This commit refactors the model loading process in common/common.cpp to enable backend samplers to be configured prior to llama_context creation. The motivation for this change is that setting or resetting the backend samplers only after the llama_context has been created causes a resize in llama_context::output_reserve, which we want to avoid.
parent 61ffe41dc1
commit 9b2439347f
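For callers, the refactor splits the old single-step common_init_from_params() into two steps, so that a backend sampler can be attached through common_params before any llama_context exists. Below is a minimal sketch of the intended call sequence, mirroring the tool change in the diff that follows; it assumes the usual params (common_params) and sparams (sampling params) already set up by the tool, and the sampler-related names (common_sampler_backend_init, llama_sampler_seq_config, backend_samplers, n_backend_samplers) are this branch's API, not upstream llama.cpp.

    // 1) load only the model; no llama_context has been created yet
    llama_model * model = common_load_model_from_params(params);
    if (model == NULL) {
        return 1;
    }

    // 2) build the backend sampler and hand it to the params, so the context
    //    is created with it already in place and no later resize is triggered
    //    in llama_context::output_reserve
    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
    if (backend_sampler) {
        llama_sampler_seq_config sampler_config = { 0, backend_sampler }; // pairs sequence id 0 (assumed) with the sampler
        params.backend_samplers = &sampler_config;
        params.n_backend_samplers = 1;
    }

    // 3) create the context from the already-loaded model
    common_init_result llama_init = common_init_context_from_model(model, params);
    llama_context * ctx = llama_init.context.get();
    model = llama_init.model.get(); // ownership is now held by llama_init

Existing callers are unaffected: as the second hunk below shows, common_init_from_params() itself becomes a thin wrapper that calls these two functions in sequence.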
@@ -943,14 +943,26 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
 // Model utils
 //
 
-struct common_init_result common_init_from_params(common_params & params) {
-    common_init_result iparams;
+llama_model * common_load_model_from_params(common_params & params) {
     auto mparams = common_model_params_to_llama(params);
 
     llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
     if (model == NULL) {
         LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
                 __func__, params.model.path.c_str());
+        return nullptr;
+    }
+
+    return model;
+}
+
+struct common_init_result common_init_context_from_model(
+        llama_model * model,
+        common_params & params) {
+    common_init_result iparams;
+
+    if (model == NULL) {
+        LOG_ERR("%s: model is NULL\n", __func__);
         return iparams;
     }
 
@@ -1125,6 +1137,14 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }
 
+struct common_init_result common_init_from_params(common_params & params) {
+    llama_model * model = common_load_model_from_params(params);
+    if (model == NULL) {
+        return common_init_result();
+    }
+    return common_init_context_from_model(model, params);
+}
+
 std::string get_model_endpoint() {
     const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
     // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@@ -640,6 +640,14 @@ struct common_init_result {
 
 struct common_init_result common_init_from_params(common_params & params);
 
+// Load model only (allows creating backend samplers before context initialization)
+llama_model * common_load_model_from_params(common_params & params);
+
+// Initialize context from an already-loaded model (allows pre-configuring backend samplers)
+struct common_init_result common_init_context_from_model(
+    llama_model * model,
+    common_params & params);
+
 struct llama_model_params common_model_params_to_llama ( common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
 struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
@@ -137,18 +137,29 @@ int main(int argc, char ** argv) {
 
     // load the model and apply lora adapter, if any
     LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
-    common_init_result llama_init = common_init_from_params(params);
-
-    model = llama_init.model.get();
-    ctx = llama_init.context.get();
 
+    model = common_load_model_from_params(params);
     if (model == NULL) {
         LOG_ERR("%s: error: unable to load model\n", __func__);
         return 1;
     }
 
-    // Configure backend sampler chain
-    llama_set_backend_sampler(ctx, 0, common_sampler_backend_init(model, sparams));
+    // Configure backend sampler if configured
+    llama_sampler * backend_sampler = common_sampler_backend_init(model, sparams);
+    if (backend_sampler) {
+        llama_sampler_seq_config sampler_config = { 0, backend_sampler };
+        params.backend_samplers = &sampler_config;
+        params.n_backend_samplers = 1;
+    }
+
+    common_init_result llama_init = common_init_context_from_model(model, params);
+    ctx = llama_init.context.get();
+    model = llama_init.model.get(); // Update pointer (now managed by llama_init)
+
+    if (ctx == NULL) {
+        LOG_ERR("%s: error: unable to create context\n", __func__);
+        return 1;
+    }
 
     auto * mem = llama_get_memory(ctx);
 