llama-fit-params: fix overflow check (#18354)
This commit is contained in:
parent
026d2ad472
commit
a4bf35889e
|
|
@ -389,8 +389,8 @@ static void llama_params_fit_impl(
|
||||||
tensor_buft_overrides[itbo].buft = nullptr;
|
tensor_buft_overrides[itbo].buft = nullptr;
|
||||||
itbo++;
|
itbo++;
|
||||||
mparams.tensor_buft_overrides = tensor_buft_overrides;
|
mparams.tensor_buft_overrides = tensor_buft_overrides;
|
||||||
throw llama_params_fit_exception("llama_params_fit_n_tensor_buft_overrides() == "
|
throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
|
||||||
+ std::to_string(ntbo) + " is insufficient for model\n");
|
+ std::to_string(ntbo) + " is insufficient for model");
|
||||||
}
|
}
|
||||||
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
|
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
|
||||||
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
|
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
|
||||||
|
|
@ -647,7 +647,7 @@ static void llama_params_fit_impl(
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
|
||||||
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
id_dense_start = id_dense_start_test;
|
id_dense_start = id_dense_start_test;
|
||||||
|
|
@ -657,7 +657,7 @@ static void llama_params_fit_impl(
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
|
||||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
id_dense_start = id_dense_start_test;
|
id_dense_start = id_dense_start_test;
|
||||||
|
|
@ -668,7 +668,7 @@ static void llama_params_fit_impl(
|
||||||
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
|
||||||
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
|
||||||
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
|
||||||
if (mem_test[id] < targets[id]) {
|
if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
|
||||||
ngl_per_device = ngl_per_device_test;
|
ngl_per_device = ngl_per_device_test;
|
||||||
mem = mem_test;
|
mem = mem_test;
|
||||||
id_dense_start = id_dense_start_test;
|
id_dense_start = id_dense_start_test;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue