Daniel Andersen 2026-02-16 15:55:30 +02:00 committed by GitHub
commit 0306d99a83
8 changed files with 103 additions and 8 deletions

View File

@@ -2331,11 +2331,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_N_GPU_LAYERS"));
add_opt(common_arg(
{"-sm", "--split-mode"}, "{none,layer,row}",
{"-sm", "--split-mode"}, "{none,layer,row,group}",
"how to split the model across multiple GPUs, one of:\n"
"- none: use one GPU only\n"
"- layer (default): split layers and KV across GPUs\n"
"- row: split rows across GPUs",
"- row: split rows across GPUs\n"
"- group: group GPUs to use minimum number needed based on available memory",
[](common_params & params, const std::string & value) {
std::string arg_next = value;
if (arg_next == "none") {
@@ -2344,6 +2345,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else if (arg_next == "group") {
params.split_mode = LLAMA_SPLIT_MODE_GROUP;
} else {
throw std::invalid_argument("invalid value");
}

View File

@@ -192,6 +192,7 @@ extern "C" {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
LLAMA_SPLIT_MODE_GROUP = 3, // group GPUs to use the minimum number needed based on available memory
};
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
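
As a point of reference, here is a minimal sketch of how an application could request the new mode through the public `llama.h` API. The entry points used (`llama_backend_init`, `llama_model_default_params`, `llama_model_load_from_file`, `llama_model_free`) are assumed from the current llama.cpp headers and are not part of this diff; the model path is a placeholder.

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.split_mode = LLAMA_SPLIT_MODE_GROUP; // enum value added by this commit

    // "model.gguf" is a placeholder path for illustration
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    // ... create a context and run inference as usual ...

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

Independently of this change, `llama_model_params` also exposes a `devices` field (a NULL-terminated device list) for applications that want to pin the loader to an explicit device set rather than rely on the group heuristic.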

View File

@@ -21,6 +21,7 @@
#include <cstdio>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <stdexcept>
#if defined(_MSC_VER)
@@ -142,6 +143,87 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
return ret;
}
static std::vector<ggml_backend_dev_t> select_min_gpu_subset(
const std::vector<ggml_backend_dev_t> & available_gpus,
const char * path_model) {
// estimated runtime memory / file size (GGUF + dequant/overhead)
constexpr double MEMORY_ESTIMATE_RATIO = 1.5;
constexpr int64_t MiB = 1024*1024;
if (available_gpus.empty()) {
return available_gpus;
}
std::vector<ggml_backend_dev_t> gpu_devices;
for (auto dev : available_gpus) {
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
gpu_devices.push_back(dev);
}
}
if (gpu_devices.empty()) {
LLAMA_LOG_INFO("%s: no GPU devices found, using all devices\n", __func__);
return available_gpus;
}
std::vector<ggml_backend_dev_t> sorted_gpus = gpu_devices;
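// sort by reported free memory, descending, so the largest devices are considered first and the subset stays minimal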
std::sort(sorted_gpus.begin(), sorted_gpus.end(), [](ggml_backend_dev_t a, ggml_backend_dev_t b) {
size_t free_a, total_a, free_b, total_b;
ggml_backend_dev_memory(a, &free_a, &total_a);
ggml_backend_dev_memory(b, &free_b, &total_b);
(void)total_a;
(void)total_b;
return free_a > free_b;
});
size_t file_size = 0;
try {
file_size = static_cast<size_t>(std::filesystem::file_size(path_model));
} catch (const std::exception & e) {
LLAMA_LOG_ERROR("%s: failed to get file size for '%s': %s\n", __func__, path_model, e.what());
LLAMA_LOG_INFO("%s: using all available devices as fallback\n", __func__);
return available_gpus;
} catch (...) {
LLAMA_LOG_ERROR("%s: failed to get file size for '%s': unknown error\n", __func__, path_model);
LLAMA_LOG_INFO("%s: using all available devices as fallback\n", __func__);
return available_gpus;
}
if (file_size == 0) {
LLAMA_LOG_ERROR("%s: model file '%s' appears to be empty\n", __func__, path_model);
LLAMA_LOG_INFO("%s: using all available devices as fallback\n", __func__);
return available_gpus;
}
size_t estimated_model_mem = static_cast<size_t>(file_size * MEMORY_ESTIMATE_RATIO);
LLAMA_LOG_DEBUG("%s: model file size: %zu MiB\n", __func__, file_size / MiB);
LLAMA_LOG_DEBUG("%s: estimated memory required: %zu MiB\n", __func__, estimated_model_mem / MiB);
std::vector<ggml_backend_dev_t> selected_gpus;
size_t cumulative_free = 0;
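// greedily add devices until their combined free memory covers the estimate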
for (auto dev : sorted_gpus) {
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
(void)total;
selected_gpus.push_back(dev);
cumulative_free += free;
if (cumulative_free >= estimated_model_mem) {
LLAMA_LOG_DEBUG("%s: selected %zu device(s) for estimated %zu MiB model memory\n",
__func__, selected_gpus.size(), estimated_model_mem / MiB);
return selected_gpus;
}
}
LLAMA_LOG_DEBUG("%s: selected all %zu device(s) for estimated %zu MiB model memory\n",
__func__, selected_gpus.size(), estimated_model_mem / MiB);
if (cumulative_free < estimated_model_mem) {
LLAMA_LOG_WARN("%s: combined free memory (%zu MiB) is less than estimated model memory (%zu MiB)\n",
__func__, cumulative_free / MiB, estimated_model_mem / MiB);
LLAMA_LOG_WARN("%s: model load may fail or run out of memory\n", __func__);
}
return selected_gpus;
}
// enum to identify part of a layer for distributing its tensors:
enum layer_fraction_t {
LAYER_FRACTION_NONE = 0, // nothing
@@ -978,6 +1060,11 @@ static struct llama_model * llama_model_load_from_file_impl(
}
}
// if using group mode, select minimum GPU subset based on free memory
if (params.split_mode == LLAMA_SPLIT_MODE_GROUP) {
model->devices = select_min_gpu_subset(model->devices, path_model.c_str());
}
// if using single GPU mode, remove all except the main GPU
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
if (params.main_gpu < 0) {
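
For a concrete feel of the heuristic: a 9.6 GiB GGUF gives an estimate of 1.5 × 9.6 ≈ 14.4 GiB, so GPUs reporting 24, 12 and 8 GiB free (visited largest first) are grouped down to just the 24 GiB card. The sketch below replays that greedy selection on mock data; the `mock_gpu` struct and the numbers are purely illustrative and do not touch the ggml backend API.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// mock stand-in for a GPU device: name + free memory in bytes
struct mock_gpu {
    const char * name;
    uint64_t     free;
};

int main() {
    constexpr double   ratio = 1.5;                     // same estimate ratio as the patch
    constexpr uint64_t GiB   = 1024ull * 1024 * 1024;

    uint64_t file_size = (uint64_t)(9.6 * GiB);         // hypothetical GGUF size
    uint64_t estimate  = (uint64_t)(file_size * ratio); // ~14.4 GiB of runtime memory

    std::vector<mock_gpu> gpus = {{"gpu0", 8*GiB}, {"gpu1", 24*GiB}, {"gpu2", 12*GiB}};

    // sort by free memory, descending, as select_min_gpu_subset does
    std::sort(gpus.begin(), gpus.end(),
              [](const mock_gpu & a, const mock_gpu & b) { return a.free > b.free; });

    // greedily take devices until the cumulative free memory covers the estimate
    std::vector<mock_gpu> selected;
    uint64_t cumulative = 0;
    for (const auto & g : gpus) {
        selected.push_back(g);
        cumulative += g.free;
        if (cumulative >= estimate) {
            break;
        }
    }

    printf("selected %zu device(s):", selected.size());
    for (const auto & g : selected) {
        printf(" %s", g.name);
    }
    printf("\n"); // prints: selected 1 device(s): gpu1
    return 0;
}
```

As in the patch, when even the combined free memory of all GPUs falls short of the estimate, every device is kept and only a warning is emitted; the load itself may still fail.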

View File

@@ -66,7 +66,7 @@
| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-sm, --split-mode {none,layer,row,group}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>- group: group GPUs to use the minimum number needed based on available memory<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
| `-fit, --fit [on\|off]` | whether to adjust unset arguments to fit in device memory ('on' or 'off', default: 'on')<br/>(env: LLAMA_ARG_FIT) |

View File

@@ -149,7 +149,7 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-sm, --split-mode {none,layer,row,group}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>- group: group GPUs to use the minimum number needed based on available memory<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
| `-fit, --fit [on\|off]` | whether to adjust unset arguments to fit in device memory ('on' or 'off', default: 'on')<br/>(env: LLAMA_ARG_FIT) |

View File

@@ -51,7 +51,7 @@ test parameters:
--poll <0...100> (default: 50)
-ngl, --n-gpu-layers <n> (default: 99)
-ncmoe, --n-cpu-moe <n> (default: 0)
-sm, --split-mode <none|layer|row> (default: layer)
-sm, --split-mode <none|layer|row|group> (default: layer)
-mg, --main-gpu <i> (default: 0)
-nkvo, --no-kv-offload <0|1> (default: 0)
-fa, --flash-attn <0|1> (default: 0)

View File

@@ -259,6 +259,8 @@ static const char * split_mode_str(llama_split_mode mode) {
return "layer";
case LLAMA_SPLIT_MODE_ROW:
return "row";
case LLAMA_SPLIT_MODE_GROUP:
return "group";
default:
GGML_ABORT("invalid split mode");
}
@@ -440,8 +442,8 @@ static void print_usage(int /* argc */, char ** argv) {
join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -ncmoe, --n-cpu-moe <n> (default: %s)\n",
join(cmd_params_defaults.n_cpu_moe, ",").c_str());
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -sm, --split-mode <none|layer|row|group> (default: %s)\n",
join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n",
join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n",
@@ -723,6 +725,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
mode = LLAMA_SPLIT_MODE_LAYER;
} else if (m == "row") {
mode = LLAMA_SPLIT_MODE_ROW;
} else if (m == "group") {
mode = LLAMA_SPLIT_MODE_GROUP;
} else {
invalid_param = true;
break;

View File

@@ -83,7 +83,7 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-sm, --split-mode {none,layer,row,group}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>- group: group GPUs to use the minimum number needed based on available memory<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
| `-fit, --fit [on\|off]` | whether to adjust unset arguments to fit in device memory ('on' or 'off', default: 'on')<br/>(env: LLAMA_ARG_FIT) |