diff --git a/src/llama.cpp b/src/llama.cpp index 6da90d6f1f..dfbea2e7e0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -81,7 +81,11 @@ static std::vector llama_get_device_memory_data( throw std::runtime_error("failed to load model"); } - llama_context * ctx = llama_init_from_model(model, *cparams); + llama_context_params cparams_copy = *cparams; + if (cparams_copy.n_ctx == 0) + cparams_copy.n_ctx = model->hparams.n_ctx_train; + + llama_context * ctx = llama_init_from_model(model, cparams_copy); if (ctx == nullptr) { llama_model_free(model); llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); @@ -236,6 +240,8 @@ static void llama_params_fit_impl( if (projected_free_per_device[0] >= margins[0]) { LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + if (cparams->n_ctx == 0) + cparams->n_ctx = hp_nct; return; } } else { @@ -248,6 +254,8 @@ static void llama_params_fit_impl( } if (!changes_needed) { LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); + if (cparams->n_ctx == 0) + cparams->n_ctx = hp_nct; return; } }