arbitrary num. of GPUs/tensor split
parent 9c7d45c0fc
commit 98ab6727e4
@@ -204,7 +204,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
 ggml_backend_dev_t ggml_backend_meta_device(
         ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
-    GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
+    GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
     static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
     static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
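This is the headline change: the meta device no longer requires a power-of-two device count. A minimal usage sketch under stated assumptions — the header name is hypothetical, and the callback/userdata stand in for a real split-state implementation such as llama_meta_device_get_split_state:

// Sketch only: wrap however many backend devices are present in one meta
// device. Before this commit n_devs had to be 1, 2, 4, or 8; now any count
// up to GGML_BACKEND_META_MAX_DEVICES passes the assert.
#include <vector>
#include "ggml-backend.h"
#include "ggml-backend-meta.h" // hypothetical header name; wherever this branch declares the meta backend

ggml_backend_dev_t make_meta_device(ggml_backend_meta_get_split_state_t cb, void * ud) {
    std::vector<ggml_backend_dev_t> devs;
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        devs.push_back(ggml_backend_dev_get(i)); // e.g. 3 or 5 GPUs are now accepted
    }
    return ggml_backend_meta_device(devs.data(), devs.size(), cb, ud);
}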
@@ -441,7 +441,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
     ggml_set_name(t_ij, tensor->name);
     t_ij->buffer = simple_buf;
     t_ij->view_offs = tensor->view_offs;
-    if (t_ij->view_offs > tensor->nb[split_dim]) {
+    if (split_dim >= 0 && split_dim < GGML_MAX_DIMS && t_ij->view_offs > tensor->nb[split_dim]) {
         t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
     }
     t_ij->view_src = tensor->view_src;
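The added bounds check ensures nb[split_dim] is only read when split_dim names a real dimension (judging by the comparison, non-split tensors use an out-of-range sentinel). The rescaling itself maps a view offset taken against the full tensor into the coordinate space of one shard. A self-contained sketch of the arithmetic — names and numbers are illustrative, not ggml API:

#include <cstdint>
#include <cstdio>

// view_offs scales with the extent along the split axis, so shrinking
// ne_full -> ne_split shrinks the offset by the same ratio.
static int64_t scale_view_offs(int64_t view_offs, int64_t ne_split, int64_t ne_full) {
    return view_offs * ne_split / ne_full;
}

int main() {
    // a view 8192 bytes into a 4096-row tensor, split 4 ways (1024 rows per shard)
    printf("%lld\n", (long long) scale_view_offs(8192, 1024, 4096)); // prints 2048
}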
@@ -18,10 +18,12 @@
 #include <algorithm>
 #include <cassert>
 #include <cfloat>
 #include <cstdint>
 #include <cstring>
 #include <cmath>
+#include <functional>
 #include <map>
+#include <numeric>
 #include <regex>
 #include <sstream>
 #include <stdexcept>
@@ -29,58 +31,61 @@
 struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
     const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
 
+    const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight");
+    const std::regex pattern_kv_weight("blk\\.\\d*\\.attn_(k|v).weight");
+    const std::regex pattern_q_bias("blk\\.\\d*\\.attn_q\\.bias");
+    const std::regex pattern_kv_bias("blk\\.\\d*\\.attn_(k|v)\\.bias");
+    const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
+    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
+    const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
+    const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
+    const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
+    const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
+    const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
+    const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
+    const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
+    const std::regex pattern_output_weight("output\\.weight");
+    const std::regex pattern_output_bias("output\\.bias");
+
     auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
         // attention
-        const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
-        if (std::regex_match(tensor->name, pattern_qkv_weight)) {
+        if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_kv_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
-        if (std::regex_match(tensor->name, pattern_qkv_bias)) {
+        if (std::regex_match(tensor->name, pattern_q_bias) || std::regex_match(tensor->name, pattern_kv_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
         if (std::regex_match(tensor->name, pattern_qk_norm)) {
             return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
-        const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
         if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
         if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
         if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
         }
 
         // FFN
-        const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
         if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
         if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
         if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
-        const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
         if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
         }
 
         // output
-        const std::regex pattern_output_weight("output\\.weight");
         if (std::regex_match(tensor->name, pattern_output_weight)) {
             return GGML_BACKEND_SPLIT_AXIS_1;
         }
-        const std::regex pattern_output_bias("output\\.bias");
         if (std::regex_match(tensor->name, pattern_output_bias)) {
             return GGML_BACKEND_SPLIT_AXIS_0;
         }
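Hoisting the patterns to function scope lets both lambdas (get_split_axis and the new get_split_granularity below) share one set of definitions. A standalone sanity check of the naming scheme — the tensor names here are illustrative examples of the llama.cpp convention, not pulled from a specific model:

#include <cassert>
#include <regex>

int main() {
    // same patterns as in the diff above
    const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight");
    const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
    assert( std::regex_match("blk.17.attn_q.weight", pattern_q_weight));   // per-layer Q projection
    assert( std::regex_match("cache_v_l3", pattern_kv_cache));             // per-layer V cache
    assert(!std::regex_match("blk.17.attn_output.weight", pattern_q_weight));
}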
@@ -89,17 +94,50 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
         return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
     };
 
+    auto get_split_granularity = [&]() -> int64_t {
+        // TODO determine this from tensors with AXIS_0
+        constexpr int64_t blck_size = 32;
+
+        // attention
+        if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||
+                std::regex_match(tensor->name, pattern_attn_out_weight)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size);
+        }
+        if (std::regex_match(tensor->name, pattern_kv_weight) || std::regex_match(tensor->name, pattern_kv_bias) ||
+                std::regex_match(tensor->name, pattern_kv_cache)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size) / n_gqa;
+        }
+        if (std::regex_match(tensor->name, pattern_attn_sinks)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa;
+        }
+
+        // FFN
+        if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight) || std::regex_match(tensor->name, pattern_ffn_up_gate_bias) ||
+                std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+            return blck_size;
+        }
+
+        // everything else
+        return 1;
+    };
+
     ggml_backend_meta_split_state split_state;
     split_state.axis = get_split_axis();
     if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
-        const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
-        const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 1 : 32; // TODO determine more generally
-        const int64_t ne_full = tensor->ne[get_split_axis()];
+        const int64_t ne_full = tensor->ne[split_state.axis];
+        const int64_t granularity = get_split_granularity();
         GGML_ASSERT(ne_full % granularity == 0);
+        const float * tensor_split = ud->model->tensor_split();
         std::vector<float> tensor_split_scan;
         tensor_split_scan.reserve(ud->n_devices);
         for (size_t j = 0; j < ud->n_devices; j++) {
-            tensor_split_scan.push_back(ud->tensor_split[j]);
+            tensor_split_scan.push_back(tensor_split[j]);
             if (j > 0) {
                 tensor_split_scan[j] += tensor_split_scan[j - 1];
             }
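The granularity appears intended to keep every shard boundary aligned both to the quantization block size and to whole GQA groups of attention heads; the loop that follows then builds an inclusive prefix sum of the per-device split fractions to place those boundaries. A worked example of the lcm math with hypothetical hparams (n_embd_head_k = 128, n_gqa = 8, so n_embd_q = 1024 — illustrative values, not from a specific model):

#include <cstdint>
#include <cstdio>
#include <numeric>

int main() {
    constexpr int64_t blck_size = 32;     // quant block size, as in the diff
    const int64_t n_gqa        = 8;       // hypothetical Q heads per KV head
    const int64_t n_embd_head_k = 128;    // hypothetical head dimension
    const int64_t n_embd_q = n_gqa * n_embd_head_k; // 1024
    printf("q/out: %lld\n", (long long)  std::lcm(n_embd_q, blck_size));                      // 1024
    printf("k/v:   %lld\n", (long long) (std::lcm(n_embd_q, blck_size) / n_gqa));             // 128
    printf("sinks: %lld\n", (long long) (std::lcm(n_embd_q, blck_size) / n_embd_q * n_gqa));  // 8
}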
@@ -439,8 +439,8 @@ struct llama_layer {
 };
 
 struct llama_meta_device_get_split_state_userdata {
-    size_t n_devices;
-    const float * tensor_split;
+    size_t n_devices;
+    const struct llama_model * model;
 };
 
 struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
@@ -921,8 +921,8 @@ static struct llama_model * llama_model_load_from_file_impl(
         while (params.devices[n_devs]) {
             n_devs++;
         }
-        model->get_split_state_ud.n_devices = n_devs;
-        model->get_split_state_ud.tensor_split = model->tensor_split();
+        model->get_split_state_ud.n_devices = n_devs;
+        model->get_split_state_ud.model = model;
         model->devices.push_back(ggml_backend_meta_device(
             params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
     } else {
@@ -946,8 +946,8 @@ static struct llama_model * llama_model_load_from_file_impl(
         }
         GGML_ASSERT(devs.size() >= 2);
         GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
-        model->get_split_state_ud.n_devices = devs.size() - 1;
-        model->get_split_state_ud.tensor_split = model->tensor_split();
+        model->get_split_state_ud.n_devices = devs.size() - 1;
+        model->get_split_state_ud.model = model;
         gpus.push_back(ggml_backend_meta_device(
             devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
     } else {