Handle n_ctx == 0 for models that fit entirely at n_ctx_train

commit 5d770b9db8 (parent a89002f07b)
author 65a, 2026-01-16 17:23:56 -08:00, committed by GitHub
1 changed file with 9 additions and 1 deletion
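For context: in llama.h, llama_context_params::n_ctx is documented as "text context, 0 = from model", i.e. a zero value asks for the model's training context length (n_ctx_train). This commit makes the device-memory fitting code resolve that convention itself instead of probing with an unresolved zero. Below is a minimal caller-side sketch of the convention, assuming the current llama.h API names (llama_model_load_from_file, llama_init_from_model, llama_model_n_ctx_train); it is an illustration, not part of the patch.

#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 0; // 0 = "take the context size from the model"

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (ctx == nullptr) {
        llama_model_free(model);
        return 1;
    }

    // With n_ctx == 0, the effective context should match n_ctx_train.
    printf("n_ctx = %u, n_ctx_train = %d\n",
           llama_n_ctx(ctx), llama_model_n_ctx_train(model));

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}

The first hunk below applies this resolution to the temporary context used for measuring device memory; the two later hunks cover the early-return paths of llama_params_fit_impl.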


@@ -81,7 +81,11 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
         throw std::runtime_error("failed to load model");
     }
-    llama_context * ctx = llama_init_from_model(model, *cparams);
+    llama_context_params cparams_copy = *cparams;
+    if (cparams_copy.n_ctx == 0)
+        cparams_copy.n_ctx = model->hparams.n_ctx_train;
+
+    llama_context * ctx = llama_init_from_model(model, cparams_copy);
     if (ctx == nullptr) {
         llama_model_free(model);
         llama_log_set(ud.original_logger.callback, ud.original_logger.user_data);
@@ -236,6 +240,8 @@ static void llama_params_fit_impl(
         if (projected_free_per_device[0] >= margins[0]) {
             LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
                 __func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
+            if (cparams->n_ctx == 0)
+                cparams->n_ctx = hp_nct;
             return;
         }
     } else {
@@ -248,6 +254,8 @@ static void llama_params_fit_impl(
         }
         if (!changes_needed) {
             LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
+            if (cparams->n_ctx == 0)
+                cparams->n_ctx = hp_nct;
             return;
         }
     }
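The two llama_params_fit_impl hunks also write the resolved value back through cparams, so a caller that passed n_ctx == 0 sees the effective context size after the early returns; hp_nct appears to be that function's local copy of hparams.n_ctx_train (an inference, the definition is outside this excerpt). The rule all three hunks apply, distilled as a standalone helper with a hypothetical name:

#include <cstdint>

// Hypothetical helper (not in the patch): resolve the "0 = from model"
// n_ctx convention against the model's training context length.
static uint32_t resolve_n_ctx(uint32_t n_ctx, uint32_t n_ctx_train) {
    return n_ctx == 0 ? n_ctx_train : n_ctx;
}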