arbitrary num. of GPUs/tensor split
This commit is contained in:
parent
9c7d45c0fc
commit
98ab6727e4
|
|
@ -204,7 +204,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
|
||||||
|
|
||||||
ggml_backend_dev_t ggml_backend_meta_device(
|
ggml_backend_dev_t ggml_backend_meta_device(
|
||||||
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
|
ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
|
||||||
GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
|
GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES);
|
||||||
static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
|
static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
|
||||||
static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
|
static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
|
||||||
|
|
||||||
|
|
@ -441,7 +441,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
|
||||||
ggml_set_name(t_ij, tensor->name);
|
ggml_set_name(t_ij, tensor->name);
|
||||||
t_ij->buffer = simple_buf;
|
t_ij->buffer = simple_buf;
|
||||||
t_ij->view_offs = tensor->view_offs;
|
t_ij->view_offs = tensor->view_offs;
|
||||||
if (t_ij->view_offs > tensor->nb[split_dim]) {
|
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS && t_ij->view_offs > tensor->nb[split_dim]) {
|
||||||
t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
|
t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim];
|
||||||
}
|
}
|
||||||
t_ij->view_src = tensor->view_src;
|
t_ij->view_src = tensor->view_src;
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,12 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <numeric>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
@ -29,58 +31,61 @@
|
||||||
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
|
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
|
||||||
const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
|
const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
|
||||||
|
|
||||||
|
const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight");
|
||||||
|
const std::regex pattern_kv_weight("blk\\.\\d*\\.attn_(k|v).weight");
|
||||||
|
const std::regex pattern_q_bias("blk\\.\\d*\\.attn_q\\.bias");
|
||||||
|
const std::regex pattern_kv_bias("blk\\.\\d*\\.attn_(k|v)\\.bias");
|
||||||
|
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
|
||||||
|
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
|
||||||
|
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
||||||
|
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
|
||||||
|
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
|
||||||
|
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
|
||||||
|
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
|
||||||
|
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
|
||||||
|
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
|
||||||
|
const std::regex pattern_output_weight("output\\.weight");
|
||||||
|
const std::regex pattern_output_bias("output\\.bias");
|
||||||
|
|
||||||
auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
|
auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
|
||||||
// attention
|
// attention
|
||||||
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
|
if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_kv_weight)) {
|
||||||
if (std::regex_match(tensor->name, pattern_qkv_weight)) {
|
|
||||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||||
}
|
}
|
||||||
const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
|
if (std::regex_match(tensor->name, pattern_q_bias) || std::regex_match(tensor->name, pattern_kv_bias)) {
|
||||||
if (std::regex_match(tensor->name, pattern_qkv_bias)) {
|
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_qk_norm)) {
|
if (std::regex_match(tensor->name, pattern_qk_norm)) {
|
||||||
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
|
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
|
||||||
}
|
}
|
||||||
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
|
|
||||||
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
|
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
|
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
|
|
||||||
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
|
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||||
}
|
}
|
||||||
|
|
||||||
// FFN
|
// FFN
|
||||||
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
|
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||||
}
|
}
|
||||||
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
|
|
||||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
|
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
|
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
|
|
||||||
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
|
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||||
}
|
}
|
||||||
|
|
||||||
// output
|
// output
|
||||||
const std::regex pattern_output_weight("output\\.weight");
|
|
||||||
if (std::regex_match(tensor->name, pattern_output_weight)) {
|
if (std::regex_match(tensor->name, pattern_output_weight)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||||
}
|
}
|
||||||
const std::regex pattern_output_bias("output\\.bias");
|
|
||||||
if (std::regex_match(tensor->name, pattern_output_bias)) {
|
if (std::regex_match(tensor->name, pattern_output_bias)) {
|
||||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||||
}
|
}
|
||||||
|
|
@ -89,17 +94,50 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
|
||||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto get_split_granularity = [&]() -> int64_t {
|
||||||
|
// TODO determine this from tensors with AXIS_0
|
||||||
|
constexpr int64_t blck_size = 32;
|
||||||
|
|
||||||
|
// attention
|
||||||
|
if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||
|
||||||
|
std::regex_match(tensor->name, pattern_attn_out_weight)) {
|
||||||
|
const uint32_t n_gqa = ud->model->hparams.n_gqa();
|
||||||
|
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
|
||||||
|
return std::lcm(n_embd_q, blck_size);
|
||||||
|
}
|
||||||
|
if (std::regex_match(tensor->name, pattern_kv_weight) || std::regex_match(tensor->name, pattern_kv_bias) ||
|
||||||
|
std::regex_match(tensor->name, pattern_kv_cache)) {
|
||||||
|
const uint32_t n_gqa = ud->model->hparams.n_gqa();
|
||||||
|
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
|
||||||
|
return std::lcm(n_embd_q, blck_size) / n_gqa;
|
||||||
|
}
|
||||||
|
if (std::regex_match(tensor->name, pattern_attn_sinks)) {
|
||||||
|
const uint32_t n_gqa = ud->model->hparams.n_gqa();
|
||||||
|
const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
|
||||||
|
return std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa;
|
||||||
|
}
|
||||||
|
|
||||||
|
// FFN
|
||||||
|
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight) || std::regex_match(tensor->name, pattern_ffn_up_gate_bias) ||
|
||||||
|
std::regex_match(tensor->name, pattern_ffn_down_weight)) {
|
||||||
|
return blck_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// everything else
|
||||||
|
return 1;
|
||||||
|
};
|
||||||
|
|
||||||
ggml_backend_meta_split_state split_state;
|
ggml_backend_meta_split_state split_state;
|
||||||
split_state.axis = get_split_axis();
|
split_state.axis = get_split_axis();
|
||||||
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
|
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
|
||||||
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
const int64_t ne_full = tensor->ne[split_state.axis];
|
||||||
const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 1 : 32; // TODO determine more generally
|
const int64_t granularity = get_split_granularity();
|
||||||
const int64_t ne_full = tensor->ne[get_split_axis()];
|
|
||||||
GGML_ASSERT(ne_full % granularity == 0);
|
GGML_ASSERT(ne_full % granularity == 0);
|
||||||
|
const float * tensor_split = ud->model->tensor_split();
|
||||||
std::vector<float> tensor_split_scan;
|
std::vector<float> tensor_split_scan;
|
||||||
tensor_split_scan.reserve(ud->n_devices);
|
tensor_split_scan.reserve(ud->n_devices);
|
||||||
for (size_t j = 0; j < ud->n_devices; j++) {
|
for (size_t j = 0; j < ud->n_devices; j++) {
|
||||||
tensor_split_scan.push_back(ud->tensor_split[j]);
|
tensor_split_scan.push_back(tensor_split[j]);
|
||||||
if (j > 0) {
|
if (j > 0) {
|
||||||
tensor_split_scan[j] += tensor_split_scan[j - 1];
|
tensor_split_scan[j] += tensor_split_scan[j - 1];
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -439,8 +439,8 @@ struct llama_layer {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_meta_device_get_split_state_userdata {
|
struct llama_meta_device_get_split_state_userdata {
|
||||||
size_t n_devices;
|
size_t n_devices;
|
||||||
const float * tensor_split;
|
const struct llama_model * model;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
|
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
|
||||||
|
|
|
||||||
|
|
@ -921,8 +921,8 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||||
while (params.devices[n_devs]) {
|
while (params.devices[n_devs]) {
|
||||||
n_devs++;
|
n_devs++;
|
||||||
}
|
}
|
||||||
model->get_split_state_ud.n_devices = n_devs;
|
model->get_split_state_ud.n_devices = n_devs;
|
||||||
model->get_split_state_ud.tensor_split = model->tensor_split();
|
model->get_split_state_ud.model = model;
|
||||||
model->devices.push_back(ggml_backend_meta_device(
|
model->devices.push_back(ggml_backend_meta_device(
|
||||||
params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -946,8 +946,8 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||||
}
|
}
|
||||||
GGML_ASSERT(devs.size() >= 2);
|
GGML_ASSERT(devs.size() >= 2);
|
||||||
GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
|
GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
|
||||||
model->get_split_state_ud.n_devices = devs.size() - 1;
|
model->get_split_state_ud.n_devices = devs.size() - 1;
|
||||||
model->get_split_state_ud.tensor_split = model->tensor_split();
|
model->get_split_state_ud.model = model;
|
||||||
gpus.push_back(ggml_backend_meta_device(
|
gpus.push_back(ggml_backend_meta_device(
|
||||||
devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue