Handle n_ctx == 0 for models that fit entirely with n_ctx_train
commit 5d770b9db8
parent a89002f07b
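Both hunks below apply the same convention: a requested context size of 0 means "use the model's training context length" (hparams.n_ctx_train). A minimal sketch of that resolution rule follows; resolve_n_ctx is a hypothetical helper name used only for illustration, not something introduced by this commit.

    // Hypothetical helper (not part of the diff) showing the n_ctx == 0 convention:
    // a value of 0 is resolved to the model's training context length.
    #include <cstdint>

    static uint32_t resolve_n_ctx(uint32_t n_ctx_requested, uint32_t n_ctx_train) {
        return n_ctx_requested == 0 ? n_ctx_train : n_ctx_requested;
    }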
@@ -81,7 +81,11 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
         throw std::runtime_error("failed to load model");
     }
 
-    llama_context * ctx = llama_init_from_model(model, *cparams);
+    llama_context_params cparams_copy = *cparams;
+    if (cparams_copy.n_ctx == 0)
+        cparams_copy.n_ctx = model->hparams.n_ctx_train;
+
+    llama_context * ctx = llama_init_from_model(model, cparams_copy);
     if (ctx == nullptr) {
         llama_model_free(model);
         llama_log_set(ud.original_logger.callback, ud.original_logger.user_data);
@@ -236,6 +240,8 @@ static void llama_params_fit_impl(
         if (projected_free_per_device[0] >= margins[0]) {
             LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
                 __func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
+            if (cparams->n_ctx == 0)
+                cparams->n_ctx = hp_nct;
             return;
         }
     } else {
@@ -248,6 +254,8 @@ static void llama_params_fit_impl(
         }
         if (!changes_needed) {
             LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
+            if (cparams->n_ctx == 0)
+                cparams->n_ctx = hp_nct;
             return;
         }
     }
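Note that the two halves of the change differ in where the resolved value ends up: llama_get_device_memory_data resolves n_ctx only on a local copy (cparams_copy), leaving the caller's parameters untouched during the memory probe, while the early-return paths in llama_params_fit_impl write the resolved value back through the cparams pointer. A hedged caller-side sketch of that write-back follows; fit_no_changes_needed is a stand-in for the early-return branches above, not the real entry point, and 32768 is an assumed n_ctx_train.

    // Stand-in for the early-return branches added above; it illustrates only the
    // n_ctx write-back, not the real llama_params_fit_impl signature.
    #include <cinttypes>
    #include <cstdio>

    static void fit_no_changes_needed(uint32_t * n_ctx, uint32_t n_ctx_train) {
        if (*n_ctx == 0) {
            *n_ctx = n_ctx_train; // report the context size that was actually assumed
        }
    }

    int main() {
        uint32_t n_ctx = 0;                   // caller left the context size unset
        fit_no_changes_needed(&n_ctx, 32768); // assumed n_ctx_train for illustration
        printf("effective n_ctx = %" PRIu32 "\n", n_ctx); // prints 32768
        return 0;
    }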