arbitrary num. of GPUs/tensor split

This commit is contained in:
Johannes Gäßler 2026-02-13 11:45:05 +01:00
parent 9c7d45c0fc
commit 98ab6727e4
4 changed files with 65 additions and 27 deletions

View File

@ -204,7 +204,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
ggml_backend_dev_t ggml_backend_meta_device(
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
@ -441,7 +441,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
ggml_set_name(t_ij, tensor->name);
t_ij->buffer = simple_buf;
t_ij->view_offs = tensor->view_offs;
if (t_ij->view_offs > tensor->nb[split_dim]) {
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS && t_ij->view_offs > tensor->nb[split_dim]) {
t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
}
t_ij->view_src = tensor->view_src;

View File

@ -18,10 +18,12 @@
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdint>
#include <cstring>
#include <cmath>
#include <functional>
#include <map>
#include <numeric>
#include <regex>
#include <sstream>
#include <stdexcept>
@ -29,58 +31,61 @@
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight");
const std::regex pattern_kv_weight("blk\\.\\d*\\.attn_(k|v).weight");
const std::regex pattern_q_bias("blk\\.\\d*\\.attn_q\\.bias");
const std::regex pattern_kv_bias("blk\\.\\d*\\.attn_(k|v)\\.bias");
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
const std::regex pattern_output_weight("output\\.weight");
const std::regex pattern_output_bias("output\\.bias");
auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
// attention
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
if (std::regex_match(tensor->name, pattern_qkv_weight)) {
if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_kv_weight)) {
return GGML_BACKEND_SPLIT_AXIS_1;
}
const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
if (std::regex_match(tensor->name, pattern_qkv_bias)) {
if (std::regex_match(tensor->name, pattern_q_bias) || std::regex_match(tensor->name, pattern_kv_bias)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
if (std::regex_match(tensor->name, pattern_qk_norm)) {
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
}
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
}
// FFN
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
return GGML_BACKEND_SPLIT_AXIS_1;
}
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
}
// output
const std::regex pattern_output_weight("output\\.weight");
if (std::regex_match(tensor->name, pattern_output_weight)) {
return GGML_BACKEND_SPLIT_AXIS_1;
}
const std::regex pattern_output_bias("output\\.bias");
if (std::regex_match(tensor->name, pattern_output_bias)) {
return GGML_BACKEND_SPLIT_AXIS_0;
}
@ -89,17 +94,50 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
};
auto get_split_granularity = [&]() -> int64_t {
// TODO determine this from tensors with AXIS_0
constexpr int64_t blck_size = 32;
// attention
if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||
std::regex_match(tensor->name, pattern_attn_out_weight)) {
const uint32_t n_gqa = ud->model->hparams.n_gqa();
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
return std::lcm(n_embd_q, blck_size);
}
if (std::regex_match(tensor->name, pattern_kv_weight) || std::regex_match(tensor->name, pattern_kv_bias) ||
std::regex_match(tensor->name, pattern_kv_cache)) {
const uint32_t n_gqa = ud->model->hparams.n_gqa();
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
return std::lcm(n_embd_q, blck_size) / n_gqa;
}
if (std::regex_match(tensor->name, pattern_attn_sinks)) {
const uint32_t n_gqa = ud->model->hparams.n_gqa();
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
return std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa;
}
// FFN
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight) || std::regex_match(tensor->name, pattern_ffn_up_gate_bias) ||
std::regex_match(tensor->name, pattern_ffn_down_weight)) {
return blck_size;
}
// everything else
return 1;
};
ggml_backend_meta_split_state split_state;
split_state.axis = get_split_axis();
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 1 : 32; // TODO determine more generally
const int64_t ne_full = tensor->ne[get_split_axis()];
const int64_t ne_full = tensor->ne[split_state.axis];
const int64_t granularity = get_split_granularity();
GGML_ASSERT(ne_full % granularity == 0);
const float * tensor_split = ud->model->tensor_split();
std::vector<float> tensor_split_scan;
tensor_split_scan.reserve(ud->n_devices);
for (size_t j = 0; j < ud->n_devices; j++) {
tensor_split_scan.push_back(ud->tensor_split[j]);
tensor_split_scan.push_back(tensor_split[j]);
if (j > 0) {
tensor_split_scan[j] += tensor_split_scan[j - 1];
}

View File

@ -439,8 +439,8 @@ struct llama_layer {
};
struct llama_meta_device_get_split_state_userdata {
size_t n_devices;
const float * tensor_split;
size_t n_devices;
const struct llama_model * model;
};
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);

View File

@ -921,8 +921,8 @@ static struct llama_model * llama_model_load_from_file_impl(
while (params.devices[n_devs]) {
n_devs++;
}
model->get_split_state_ud.n_devices = n_devs;
model->get_split_state_ud.tensor_split = model->tensor_split();
model->get_split_state_ud.n_devices = n_devs;
model->get_split_state_ud.model = model;
model->devices.push_back(ggml_backend_meta_device(
params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
} else {
@ -946,8 +946,8 @@ static struct llama_model * llama_model_load_from_file_impl(
}
GGML_ASSERT(devs.size() >= 2);
GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
model->get_split_state_ud.n_devices = devs.size() - 1;
model->get_split_state_ud.tensor_split = model->tensor_split();
model->get_split_state_ud.n_devices = devs.size() - 1;
model->get_split_state_ud.model = model;
gpus.push_back(ggml_backend_meta_device(
devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
} else {