diff --git a/src/llama.cpp b/src/llama.cpp index 93a9c408ba..76b3acbadb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -512,6 +512,9 @@ static void llama_params_fit_impl( if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + if (hp_nex > 0 && size_t(id) == nd - 1) { + delta--; + } LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);