llama-fit-params: fix Gemma 3 calculation (#18372)

Johannes Gäßler, 2025-12-27 09:56:04 +01:00, committed by GitHub
parent c9ced4910b, commit 9045c9afe5
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
1 changed file with 25 additions and 20 deletions


@@ -181,12 +181,11 @@ static void llama_params_fit_impl(
         }
     }
-    int64_t sum_total = 0;
+    int64_t sum_free = 0;
     int64_t sum_projected_free = 0;
     int64_t min_projected_free = INT64_MAX;
     int64_t sum_projected_used = 0;
     int64_t sum_projected_model = 0;
-    int64_t sum_projected_ctx = 0;
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -197,12 +196,11 @@ static void llama_params_fit_impl(
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
-        sum_total += dmd.total;
+        sum_free += dmd.free;
         sum_projected_used += projected_used;
         sum_projected_free += projected_free;
         min_projected_free = std::min(min_projected_free, projected_free);
         sum_projected_model += dmd.mb.model;
-        sum_projected_ctx += dmd.mb.context;
         if (nd > 1) {
             LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@@ -210,10 +208,9 @@ static void llama_params_fit_impl(
                 projected_free >= 0 ? "surplus" : "deficit");
         }
     }
-    assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
-    assert(sum_projected_used >= sum_projected_ctx);
+    assert(sum_free >= 0 && sum_projected_used >= 0);
     LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
-        __func__, sum_projected_used/MiB, sum_total/MiB);
+        __func__, sum_projected_used/MiB, sum_free/MiB);
     if (min_projected_free >= margin) {
         if (nd == 1) {
             LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
@@ -236,9 +233,7 @@ static void llama_params_fit_impl(
             __func__, margin/MiB, -global_surplus/MiB);
     if (cparams->n_ctx == 0) {
         if (hp_nct > n_ctx_min) {
-            const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-            int64_t memory_reduction = -global_surplus;
+            int64_t sum_used_target = sum_free - nd*margin_s;
             if (nd > 1) {
                 // for multiple devices we need to be more conservative in terms of how much context we think can fit:
                 // - for dense models only whole layers can be assigned to devices
@@ -246,24 +241,34 @@ static void llama_params_fit_impl(
                 // - on average we expect a waste of 0.5 layers/tensors per device
                 // - use slightly more than the expected average for nd devices to be safe
                 const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
-                memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+                sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
             }
-            uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
-            cparams->n_ctx = hp_nct - ctx_reduction;
-            cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
-            ctx_reduction = hp_nct - cparams->n_ctx;
-            memory_reduction = ctx_reduction * bytes_per_ctx;
-            global_surplus += memory_reduction;
-            LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
-                __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
-            if (global_surplus >= 0) {
+            int64_t sum_projected_used_min_ctx = 0;
+            cparams->n_ctx = n_ctx_min;
+            const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
+            for (const auto & dmd : dmds_min_ctx) {
+                sum_projected_used_min_ctx += dmd.mb.total();
+            }
+            if (sum_used_target > sum_projected_used_min_ctx) {
+                // linear interpolation between minimum and maximum context size:
+                cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
+                    / (sum_projected_used - sum_projected_used_min_ctx);
+                cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+                const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
+                const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
+                LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+                    __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
                 if (nd == 1) {
                     LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
                     return;
                 }
                 LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+            } else {
+                const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
+                LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+                    __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
             }
         } else {
             LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",