From 98ab6727e474f4c518345a0668edd6b5cef0b2e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 13 Feb 2026 11:45:05 +0100
Subject: [PATCH] arbitrary num. of GPUs/tensor split

---
 ggml/src/ggml-backend-meta.cpp |  4 +-
 src/llama-model.cpp            | 76 +++++++++++++++++++++++++---------
 src/llama-model.h              |  4 +-
 src/llama.cpp                  |  8 ++--
 4 files changed, 65 insertions(+), 27 deletions(-)

diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 4a079c87bf..5b939fdf62 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -204,7 +204,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
 
 ggml_backend_dev_t ggml_backend_meta_device(
         ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
-    GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
+    GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
 
     static std::vector> ctxs;
     static std::map meta_devs;
@@ -441,7 +441,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
         ggml_set_name(t_ij, tensor->name);
         t_ij->buffer = simple_buf;
         t_ij->view_offs = tensor->view_offs;
-        if (t_ij->view_offs > tensor->nb[split_dim]) {
+        if (split_dim >= 0 && split_dim < GGML_MAX_DIMS && t_ij->view_offs > tensor->nb[split_dim]) {
             t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
         }
         t_ij->view_src = tensor->view_src;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 4aed50c903..5a01c5e065 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -18,10 +18,12 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -29,58 +31,61 @@
 struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
     const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
 
+    const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight");
+    const std::regex pattern_kv_weight("blk\\.\\d*\\.attn_(k|v).weight");
+    const std::regex pattern_q_bias("blk\\.\\d*\\.attn_q\\.bias");
+    const std::regex pattern_kv_bias("blk\\.\\d*\\.attn_(k|v)\\.bias");
+    const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
+    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
+    const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
+    const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
+    const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
+    const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
+    const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
+    const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
+    const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
+    const std::regex pattern_output_weight("output\\.weight");
+    const std::regex pattern_output_bias("output\\.bias");
+
     auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
         // attention
-        const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
-        if (std::regex_match(tensor->name, pattern_qkv_weight)) {
+        if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_kv_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
-        if (std::regex_match(tensor->name, pattern_qkv_bias)) {
+        if (std::regex_match(tensor->name, pattern_q_bias) || std::regex_match(tensor->name, pattern_kv_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
         if (std::regex_match(tensor->name, pattern_qk_norm)) {
             return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
-        const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
         if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
         if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
         if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
         }
 
         // FFN
-        const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
         if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
         if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
         if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
         if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
         }
 
         // output
-        const std::regex pattern_output_weight("output\\.weight");
         if (std::regex_match(tensor->name, pattern_output_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_output_bias("output\\.bias");
         if (std::regex_match(tensor->name, pattern_output_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
@@ -89,17 +94,50 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
         return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
     };
 
+    auto get_split_granularity = [&]() -> int64_t {
+        // TODO determine this from tensors with AXIS_0
+        constexpr int64_t blck_size = 32;
+
+        // attention
+        if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||
+                std::regex_match(tensor->name, pattern_attn_out_weight)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size);
+        }
+        if (std::regex_match(tensor->name, pattern_kv_weight) || std::regex_match(tensor->name, pattern_kv_bias) ||
+                std::regex_match(tensor->name, pattern_kv_cache)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size) / n_gqa;
+        }
+        if (std::regex_match(tensor->name, pattern_attn_sinks)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa;
+        }
+
+        // FFN
+        if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight) || std::regex_match(tensor->name, pattern_ffn_up_gate_bias) ||
+                std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+            return blck_size;
+        }
+
+        // everything else
+        return 1;
+    };
+
     ggml_backend_meta_split_state split_state;
     split_state.axis = get_split_axis();
     if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
-        const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
-        const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 1 : 32; // TODO determine more generally
-        const int64_t ne_full = tensor->ne[get_split_axis()];
+        const int64_t ne_full = tensor->ne[split_state.axis];
+        const int64_t granularity = get_split_granularity();
         GGML_ASSERT(ne_full % granularity == 0);
+        const float * tensor_split = ud->model->tensor_split();
         std::vector<float> tensor_split_scan;
         tensor_split_scan.reserve(ud->n_devices);
         for (size_t j = 0; j < ud->n_devices; j++) {
-            tensor_split_scan.push_back(ud->tensor_split[j]);
+            tensor_split_scan.push_back(tensor_split[j]);
             if (j > 0) {
                 tensor_split_scan[j] += tensor_split_scan[j - 1];
             }
diff --git a/src/llama-model.h b/src/llama-model.h
index c9ff9a991b..fedded8585 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -439,8 +439,8 @@ struct llama_layer {
 };
 
 struct llama_meta_device_get_split_state_userdata {
-    size_t        n_devices;
-    const float * tensor_split;
+    size_t                     n_devices;
+    const struct llama_model * model;
 };
 
 struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
diff --git a/src/llama.cpp b/src/llama.cpp
index bee9567352..1ea326107b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -921,8 +921,8 @@ static struct llama_model * llama_model_load_from_file_impl(
         while (params.devices[n_devs]) {
             n_devs++;
         }
-        model->get_split_state_ud.n_devices    = n_devs;
-        model->get_split_state_ud.tensor_split = model->tensor_split();
+        model->get_split_state_ud.n_devices = n_devs;
+        model->get_split_state_ud.model     = model;
         model->devices.push_back(ggml_backend_meta_device(
             params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
     } else {
@@ -946,8 +946,8 @@ static struct llama_model * llama_model_load_from_file_impl(
         }
         GGML_ASSERT(devs.size() >= 2);
         GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
-        model->get_split_state_ud.n_devices    = devs.size() - 1;
-        model->get_split_state_ud.tensor_split = model->tensor_split();
+        model->get_split_state_ud.n_devices = devs.size() - 1;
+        model->get_split_state_ud.model     = model;
         gpus.push_back(ggml_backend_meta_device(
             devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
     } else {