llama-fit-params: fix overflow check (#18354)

This commit is contained in:
Johannes Gäßler 2025-12-27 20:20:45 +01:00 committed by GitHub
parent 026d2ad472
commit a4bf35889e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 5 deletions

View File

@ -389,8 +389,8 @@ static void llama_params_fit_impl(
tensor_buft_overrides[itbo].buft = nullptr;
itbo++;
mparams.tensor_buft_overrides = tensor_buft_overrides;
throw llama_params_fit_exception("llama_params_fit_n_tensor_buft_overrides() == "
+ std::to_string(ntbo) + " is insufficient for model\n");
throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
+ std::to_string(ntbo) + " is insufficient for model");
}
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
@ -647,7 +647,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@ -657,7 +657,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@ -668,7 +668,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
if (mem_test[id] < targets[id]) {
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;