llama-fit-params: fix step size for last device (#18415)
This commit is contained in:
parent
e59efe6a78
commit
f8d561eb87
|
|
@@ -512,6 +512,9 @@ static void llama_params_fit_impl(
|
||||||
if (mem_high[id] > targets[id]) {
|
if (mem_high[id] > targets[id]) {
|
||||||
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
|
||||||
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
|
||||||
|
if (hp_nex > 0 && size_t(id) == nd - 1) {
|
||||||
|
delta--;
|
||||||
|
}
|
||||||
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
|
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
|
||||||
while (delta > 1) {
|
while (delta > 1) {
|
||||||
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue