support for tensor dims % n_devs != 0
This commit is contained in:
parent
b12a56351d
commit
31e4f189bb
@@ -223,24 +223,31 @@ extern "C" {
// Meta backend
|
||||
//
|
||||
|
||||
enum ggml_backend_meta_split_state {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_STATE_BY_NE0 = 0,
|
||||
GGML_BACKEND_SPLIT_STATE_BY_NE1 = 1,
|
||||
GGML_BACKEND_SPLIT_STATE_BY_NE2 = 2,
|
||||
GGML_BACKEND_SPLIT_STATE_BY_NE3 = 3,
|
||||
#define GGML_BACKEND_META_MAX_DEVICES 16
|
||||
|
||||
GGML_BACKEND_SPLIT_STATE_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_STATE_PARTIAL = 11, // each backend has a partial sum
|
||||
enum ggml_backend_meta_split_axis {
|
||||
// tensor split by tensor dimensions:
|
||||
GGML_BACKEND_SPLIT_AXIS_0 = 0,
|
||||
GGML_BACKEND_SPLIT_AXIS_1 = 1,
|
||||
GGML_BACKEND_SPLIT_AXIS_2 = 2,
|
||||
GGML_BACKEND_SPLIT_AXIS_3 = 3,
|
||||
|
||||
GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
|
||||
GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum
|
||||
|
||||
// for internal bookkeeping only:
|
||||
GGML_BACKEND_SPLIT_STATE_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_STATE_UNKNOWN = 99,
|
||||
GGML_BACKEND_SPLIT_AXIS_NONE = 98,
|
||||
GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
|
||||
};
|
||||
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
|
||||
|
||||
struct ggml_backend_meta_split_state {
|
||||
enum ggml_backend_meta_split_axis axis;
|
||||
int64_t ne[GGML_BACKEND_META_MAX_DEVICES];
|
||||
};
|
||||
|
||||
// function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
|
||||
typedef enum ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
GGML_API bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
|
||||
GGML_API size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev);
|
||||
|
|
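For orientation, a minimal self-contained sketch (not part of the diff; all names and sizes below are illustrative stand-ins) of what the new ggml_backend_meta_split_state struct expresses: a split axis plus one extent per device, so the per-device shares no longer have to be equal when the dimension is not divisible by the number of devices.

#include <cstdint>
#include <cstdio>

// Stand-ins for the enum/struct declared above.
enum split_axis { AXIS_0 = 0, AXIS_1 = 1, AXIS_2 = 2, AXIS_3 = 3, AXIS_MIRRORED = 10, AXIS_PARTIAL = 11 };
constexpr int MAX_DEVICES = 16;

struct split_state {
    split_axis axis;            // which dimension is split (or MIRRORED/PARTIAL)
    int64_t    ne[MAX_DEVICES]; // per-device extent along that axis, may differ between devices
};

int main() {
    // 4100 rows over 3 devices: 4100 % 3 != 0, so the shares are stored explicitly.
    const split_state ss = {AXIS_1, {1367, 1367, 1366}};
    int64_t total = 0;
    for (int j = 0; j < 3; j++) {
        std::printf("device %d: %lld rows\n", j, (long long) ss.ne[j]);
        total += ss.ne[j];
    }
    std::printf("total: %lld rows\n", (long long) total); // equals the full dimension
}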
@@ -263,7 +270,7 @@ extern "C" {
GGML_API size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend);
|
||||
GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index);
|
||||
|
||||
GGML_API enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
|
||||
GGML_API struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
|
||||
|
||||
// temporary workaround to statically allocate tensors from a context in a deduplicated way:
|
||||
GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||
|
|
|
|||
|
|
@@ -20,6 +20,29 @@ struct ggml_backend_meta_buffer_type;
struct ggml_backend_meta_buffer;
|
||||
struct ggml_backend_meta;
|
||||
|
||||
const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) {
|
||||
switch (split_axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
return "0";
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
return "1";
|
||||
case GGML_BACKEND_SPLIT_AXIS_2:
|
||||
return "2";
|
||||
case GGML_BACKEND_SPLIT_AXIS_3:
|
||||
return "3";
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
||||
return "MIRRORED";
|
||||
case GGML_BACKEND_SPLIT_AXIS_PARTIAL:
|
||||
return "PARTIAL";
|
||||
case GGML_BACKEND_SPLIT_AXIS_NONE:
|
||||
return "NONE";
|
||||
case GGML_BACKEND_SPLIT_AXIS_UNKNOWN:
|
||||
return "UNKNOWN";
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// meta backend device
|
||||
//
|
||||
|
|
@@ -351,6 +374,13 @@ struct ggml_backend_meta_buffer_context {
buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {}
|
||||
};
|
||||
std::vector<buffer_config> buf_configs;
|
||||
|
||||
int debug;
|
||||
|
||||
ggml_backend_meta_buffer_context() {
|
||||
const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG");
|
||||
debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
|
|
@@ -374,32 +404,32 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer);
|
||||
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true);
|
||||
GGML_ASSERT(split_state != GGML_BACKEND_SPLIT_STATE_UNKNOWN);
|
||||
GGML_ASSERT(split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
||||
|
||||
int split_dim = split_state;
|
||||
int split_dim = split_state.axis;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
for (size_t k = 0; k < GGML_MAX_DIMS; k++) {
|
||||
ne[k] = tensor->ne[k];
|
||||
nb[k] = tensor->nb[k];
|
||||
}
|
||||
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
|
||||
GGML_ASSERT(ne[split_dim] % (split_dim == 0 ? n_simple_bufs*ggml_blck_size(tensor->type) : n_simple_bufs) == 0);
|
||||
ne[split_dim] /= n_simple_bufs;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->nb[i] > tensor->nb[split_dim]) {
|
||||
GGML_ASSERT(nb[i] % (n_simple_bufs*ggml_element_size(tensor)) == 0);
|
||||
nb[i] /= n_simple_bufs;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ggml_tensor *> simple_tensors;
|
||||
simple_tensors.reserve(buf_ctx->buf_configs.size());
|
||||
for (size_t j = 0; j < buf_ctx->buf_configs.size(); j++) {
|
||||
simple_tensors.reserve(n_simple_bufs);
|
||||
for (size_t j = 0; j < n_simple_bufs; j++) {
|
||||
ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx;
|
||||
ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf;
|
||||
|
||||
if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(tensor));
|
||||
ne[split_dim] = split_state.ne[j];
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (tensor->nb[i] > tensor->nb[split_dim]) {
|
||||
nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne);
|
||||
t_ij->op = tensor->op;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
|
|
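A rough sketch of the per-device shape computation that the init path above performs when a tensor is split along one axis: device j's extent along the split axis comes straight from split_state.ne[j], and every stride that lies outside the split axis is rescaled by the same ratio, so uneven shares work as long as the tensor is contiguously allocated. The helper below is illustrative only and uses plain arrays instead of ggml types.

#include <cstddef>
#include <cstdint>

constexpr int MAX_DIMS = 4;

// ne_full/nb_full describe the full tensor; split_dim is the split axis; ne_j is device j's share.
static void device_shape(const int64_t ne_full[MAX_DIMS], const size_t nb_full[MAX_DIMS],
                         int split_dim, int64_t ne_j,
                         int64_t ne_out[MAX_DIMS], size_t nb_out[MAX_DIMS]) {
    for (int i = 0; i < MAX_DIMS; i++) {
        ne_out[i] = ne_full[i];
        nb_out[i] = nb_full[i];
    }
    ne_out[split_dim] = ne_j; // need not be ne_full[split_dim]/n_devs
    for (int i = 0; i < MAX_DIMS; i++) {
        if (nb_full[i] > nb_full[split_dim]) { // only strides outside the split axis shrink
            nb_out[i] = nb_full[i] * ne_j / ne_full[split_dim];
        }
    }
}

int main() {
    // f32 tensor of shape 4096 x 4100; device 0 holds 1367 of the 4100 rows.
    const int64_t ne_full[MAX_DIMS] = {4096, 4100, 1, 1};
    const size_t  nb_full[MAX_DIMS] = {4, 4*4096, 4*4096*4100, 4*4096*4100};
    int64_t ne_j[MAX_DIMS];
    size_t  nb_j[MAX_DIMS];
    device_shape(ne_full, nb_full, /*split_dim =*/ 1, /*ne_j =*/ 1367, ne_j, nb_j);
    return ne_j[1] == 1367 ? 0 : 1;
}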
@@ -444,12 +474,12 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
|
||||
|
||||
switch (split_state) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
|
||||
switch (split_state.axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2: {
|
||||
// Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
|
||||
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
|
||||
GGML_ASSERT(offset % chunk_size_full == 0);
|
||||
GGML_ASSERT(size % chunk_size_full == 0);
|
||||
const int64_t i_start = offset /chunk_size_full;
|
||||
|
|
@@ -457,13 +487,13 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg
size_t offset_j = 0;
|
||||
for (size_t j = 0; j < n_bufs; j++){
|
||||
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
|
||||
const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
|
||||
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
|
||||
offset_j += chunk_size_j;
|
||||
}
|
||||
GGML_ASSERT(offset_j == chunk_size_full);
|
||||
} break;
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
|
||||
for (size_t j = 0; j < n_bufs; j++){
|
||||
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
|
||||
ggml_backend_tensor_set(simple_tensor, data, offset, size);
|
||||
|
|
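The chunked copy above relies on the full tensor being contiguous: each slab of size nb[axis + 1] is cut into one piece per device, and the piece sizes come from the device tensors' own strides, so they may differ. A standalone sketch of that slicing with plain byte buffers (std::vector::insert stands in for the backend 2D copies; all names are illustrative).

#include <cstdio>
#include <vector>

// chunk_full = slab size of the full tensor, chunk_dev[j] = slab size of device j's tensor.
static void scatter_chunks(const char * src, size_t n_chunks, size_t chunk_full,
                           const std::vector<size_t> & chunk_dev,
                           std::vector<std::vector<char>> & dst) {
    for (size_t i = 0; i < n_chunks; i++) {
        size_t off = 0;
        for (size_t j = 0; j < chunk_dev.size(); j++) {
            dst[j].insert(dst[j].end(), src + i*chunk_full + off, src + i*chunk_full + off + chunk_dev[j]);
            off += chunk_dev[j];
        }
        // off == chunk_full must hold here, as asserted in the real code
    }
}

int main() {
    std::vector<char> data(3*10);                    // 3 slabs of 10 bytes each
    for (size_t i = 0; i < data.size(); i++) data[i] = (char) i;
    const std::vector<size_t> chunk_dev = {4, 3, 3}; // uneven split: 10 = 4 + 3 + 3
    std::vector<std::vector<char>> dst(3);
    scatter_chunks(data.data(), 3, 10, chunk_dev, dst);
    for (size_t j = 0; j < dst.size(); j++) std::printf("device %zu got %zu bytes\n", j, dst[j].size());
}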
@@ -482,12 +512,12 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
|
||||
|
||||
switch (split_state) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
|
||||
switch (split_state.axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2: {
|
||||
// Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
|
||||
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
|
||||
GGML_ASSERT(offset % chunk_size_full == 0);
|
||||
GGML_ASSERT(size % chunk_size_full == 0);
|
||||
const int64_t i_start = offset /chunk_size_full;
|
||||
|
|
@@ -495,13 +525,13 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co
size_t offset_j = 0;
|
||||
for (size_t j = 0; j < n_bufs; j++){
|
||||
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
|
||||
const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
|
||||
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full);
|
||||
offset_j += chunk_size_j;
|
||||
}
|
||||
GGML_ASSERT(offset_j == chunk_size_full);
|
||||
} break;
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
|
||||
// TODO other simple backend may be better
|
||||
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
|
||||
ggml_backend_tensor_get(simple_tensor, data, offset, size);
|
||||
|
|
@@ -578,7 +608,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context;
|
||||
ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context();
|
||||
size_t max_size = 0;
|
||||
buf_ctx->buf_configs.reserve(n_simple_bufts);
|
||||
for (size_t i = 0; i < n_simple_bufts; i++) {
|
||||
|
|
@@ -599,7 +629,7 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context;
|
||||
ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context();
|
||||
meta_buf_ctx->buf_configs.reserve(n_simple_bufts);
|
||||
for (size_t i = 0; i < n_simple_bufts; i++) {
|
||||
meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr);
|
||||
|
|
@@ -723,12 +753,12 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
|
||||
|
||||
switch (split_state) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
|
||||
switch (split_state.axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2: {
|
||||
// Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
|
||||
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
|
||||
GGML_ASSERT(offset % chunk_size_full == 0);
|
||||
GGML_ASSERT(size % chunk_size_full == 0);
|
||||
const int64_t i_start = offset /chunk_size_full;
|
||||
|
|
@@ -737,14 +767,14 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens
for (size_t j = 0; j < n_backends; j++){
|
||||
ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
|
||||
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
|
||||
const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
|
||||
ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j,
|
||||
i_stop - i_start, chunk_size_j, chunk_size_full);
|
||||
offset_j += chunk_size_j;
|
||||
}
|
||||
GGML_ASSERT(offset_j == chunk_size_full);
|
||||
} break;
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
|
||||
for (size_t j = 0; j < n_backends; j++) {
|
||||
ggml_backend_tensor_set_async(
|
||||
ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size);
|
||||
|
|
@@ -763,12 +793,12 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
|
||||
|
||||
switch (split_state) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
|
||||
switch (split_state.axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2: {
|
||||
// Exploit that tensors are contiguous to splice it with simple tensors as "chunks".
|
||||
const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_full = tensor->nb[split_state.axis + 1];
|
||||
GGML_ASSERT(offset % chunk_size_full == 0);
|
||||
GGML_ASSERT(size % chunk_size_full == 0);
|
||||
const int64_t i_start = offset /chunk_size_full;
|
||||
|
|
@@ -777,14 +807,14 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm
for (size_t j = 0; j < n_backends; j++){
|
||||
ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
|
||||
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
|
||||
const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
|
||||
const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1];
|
||||
ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j,
|
||||
i_stop - i_start, chunk_size_j, chunk_size_full);
|
||||
offset_j += chunk_size_j;
|
||||
}
|
||||
GGML_ASSERT(offset_j == chunk_size_full);
|
||||
} break;
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED: {
|
||||
// TODO other simple backend may be better
|
||||
ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
|
||||
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
|
||||
|
|
@@ -826,11 +856,11 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
int i_start = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
const bool partial = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false) == GGML_BACKEND_SPLIT_STATE_PARTIAL;
|
||||
if (partial) {
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false);
|
||||
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
|
||||
max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node));
|
||||
}
|
||||
const bool new_subgraph = i + 1 == cgraph->n_nodes || partial;
|
||||
const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL;
|
||||
if (!new_subgraph) {
|
||||
continue;
|
||||
}
|
||||
|
|
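In the compute path above, the graph is cut into subgraphs at every node whose split state is PARTIAL: each backend then holds only a partial sum for that node, so a cross-device reduction is needed before later nodes can consume it, and the largest such node determines the size of the temporary buffer. A sketch of just that partitioning over a plain node list (types and names are illustrative).

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

struct node_info { size_t nbytes; bool partial; };

// Split [0, nodes.size()) into ranges that end right after each PARTIAL node;
// max_tmp receives the size of the largest PARTIAL node.
static void split_graph(const std::vector<node_info> & nodes,
                        std::vector<std::pair<int, int>> & ranges, size_t & max_tmp) {
    max_tmp = 0;
    int start = 0;
    for (int i = 0; i < (int) nodes.size(); i++) {
        if (nodes[i].partial) {
            max_tmp = std::max(max_tmp, nodes[i].nbytes);
        }
        if (i + 1 == (int) nodes.size() || nodes[i].partial) {
            ranges.emplace_back(start, i + 1);
            start = i + 1;
        }
    }
}

int main() {
    const std::vector<node_info> nodes = {{64, false}, {256, true}, {64, false}, {128, false}};
    std::vector<std::pair<int, int>> ranges;
    size_t max_tmp = 0;
    split_graph(nodes, ranges, max_tmp); // ranges: [0,2) and [2,4); max_tmp == 256
}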
@@ -1039,266 +1069,299 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz
return backend_ctx->backend_configs[index].backend;
|
||||
}
|
||||
|
||||
enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer));
|
||||
struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) {
|
||||
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
|
||||
ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context;
|
||||
|
||||
auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool {
|
||||
if (a.axis != b.axis) {
|
||||
return false;
|
||||
}
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
if (a.ne[j] != b.ne[j]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states, bool scalar_only) -> ggml_backend_meta_split_state {
|
||||
ggml_backend_meta_split_state homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_NONE;
|
||||
ggml_backend_meta_split_state homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}};
|
||||
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
|
||||
continue;
|
||||
}
|
||||
if (homogeneous_src_split_state == GGML_BACKEND_SPLIT_STATE_NONE) {
|
||||
if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
|
||||
homogeneous_src_split_state = src_split_states[i];
|
||||
} else if (src_split_states[i] != homogeneous_src_split_state) {
|
||||
homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
} else if (!split_states_equal(src_split_states[i], homogeneous_src_split_state)) {
|
||||
homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (homogeneous_src_split_state == GGML_BACKEND_SPLIT_STATE_NONE) {
|
||||
homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) {
|
||||
homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
}
|
||||
if (scalar_only && homogeneous_src_split_state >= 0 && homogeneous_src_split_state < GGML_MAX_DIMS) {
|
||||
homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
if (scalar_only && homogeneous_src_split_state.axis >= 0 && homogeneous_src_split_state.axis < GGML_MAX_DIMS) {
|
||||
homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
}
|
||||
GGML_ASSERT(homogeneous_src_split_state != GGML_BACKEND_SPLIT_STATE_UNKNOWN);
|
||||
GGML_ASSERT(homogeneous_src_split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
||||
return homogeneous_src_split_state;
|
||||
};
|
||||
|
||||
// Some ops process data on a per-row basis:
|
||||
auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
GGML_ASSERT(src_split_states[0] != GGML_BACKEND_SPLIT_STATE_BY_NE0);
|
||||
GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_0);
|
||||
return src_split_states[0];
|
||||
};
|
||||
|
||||
// Some ops broadcast the src1 data across src0:
|
||||
auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
if (src_split_states[0] >= 0 && src_split_states[0] < GGML_MAX_DIMS &&
|
||||
tensor->src[1]->ne[int(src_split_states[0])] == 1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
|
||||
if (src_split_states[0].axis >= 0 && src_split_states[0].axis < GGML_MAX_DIMS &&
|
||||
tensor->src[1]->ne[src_split_states[0].axis] == 1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||
return src_split_states[0];
|
||||
}
|
||||
if (src_split_states[0] == src_split_states[1] && src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
|
||||
if (src_split_states[0].axis == src_split_states[1].axis && src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||
return src_split_states[0]; // GGML_ADD_ID
|
||||
}
|
||||
GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
|
||||
GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
};
|
||||
|
||||
auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
|
||||
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||
return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}};
|
||||
}
|
||||
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
|
||||
ggml_backend_meta_split_state ret = src_split_states[0];
|
||||
ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
|
||||
return ret;
|
||||
}
|
||||
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE0 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE0) {
|
||||
return assume_sync ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_PARTIAL;
|
||||
if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
GGML_ASSERT(src_split_states[0].ne[j] == src_split_states[1].ne[j]);
|
||||
}
|
||||
return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}};
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
};
|
||||
|
||||
auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
switch (src_split_states[0]) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE3: {
|
||||
switch (src_split_states[0].axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2:
|
||||
case GGML_BACKEND_SPLIT_AXIS_3: {
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor));
|
||||
int64_t base_ne_in = 1;
|
||||
for (int dim = 0; dim <= int(src_split_states[0]); dim++) {
|
||||
for (int dim = 0; dim <= src_split_states[0].axis; dim++) {
|
||||
base_ne_in *= tensor->src[0]->ne[dim];
|
||||
}
|
||||
int64_t base_ne_out = 1;
|
||||
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
|
||||
const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim];
|
||||
if (base_ne_out_next == base_ne_in) {
|
||||
return ggml_backend_meta_split_state(dim);
|
||||
return {ggml_backend_meta_split_axis(dim), {0}};
|
||||
}
|
||||
if (base_ne_out_next > base_ne_in) {
|
||||
GGML_ASSERT(dim + 1 < GGML_MAX_DIMS);
|
||||
return {ggml_backend_meta_split_axis(dim + 1), {0}};
|
||||
}
|
||||
base_ne_out = base_ne_out_next;
|
||||
}
|
||||
GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op));
|
||||
}
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED:
|
||||
case GGML_BACKEND_SPLIT_STATE_PARTIAL: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
||||
case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
|
||||
return src_split_states[0];
|
||||
}
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
if (ggml_is_contiguous(tensor)) {
|
||||
if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->view_src)) {
|
||||
return handle_reshape(src_split_states);
|
||||
}
|
||||
if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED || src_split_states[0] == GGML_BACKEND_SPLIT_STATE_PARTIAL) {
|
||||
if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) {
|
||||
return src_split_states[0];
|
||||
}
|
||||
GGML_ABORT("non-contioguos view not implemented");
|
||||
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
GGML_ABORT("view of permuted tensor not implemented");
|
||||
return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
};
|
||||
|
||||
auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
switch (src_split_states[0]) {
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE0:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE1:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE2:
|
||||
case GGML_BACKEND_SPLIT_STATE_BY_NE3: {
|
||||
return ggml_backend_meta_split_state(tensor->op_params[int(src_split_states[0])]);
|
||||
switch (src_split_states[0].axis) {
|
||||
case GGML_BACKEND_SPLIT_AXIS_0:
|
||||
case GGML_BACKEND_SPLIT_AXIS_1:
|
||||
case GGML_BACKEND_SPLIT_AXIS_2:
|
||||
case GGML_BACKEND_SPLIT_AXIS_3: {
|
||||
return {ggml_backend_meta_split_axis(tensor->op_params[src_split_states[0].axis]), {0}};
|
||||
}
|
||||
case GGML_BACKEND_SPLIT_STATE_MIRRORED:
|
||||
case GGML_BACKEND_SPLIT_STATE_PARTIAL: {
|
||||
case GGML_BACKEND_SPLIT_AXIS_MIRRORED:
|
||||
case GGML_BACKEND_SPLIT_AXIS_PARTIAL: {
|
||||
return src_split_states[0];
|
||||
}
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
|
||||
GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
|
||||
GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
|
||||
GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
||||
GGML_ASSERT(split_states_equal(src_split_states[0], src_split_states[2]));
|
||||
return src_split_states[0];
|
||||
};
|
||||
|
||||
auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
|
||||
GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
||||
return src_split_states[0];
|
||||
};
|
||||
|
||||
auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_split_states) -> ggml_backend_meta_split_state {
|
||||
GGML_ASSERT( src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
|
||||
GGML_ASSERT( src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
|
||||
GGML_ASSERT( src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2);
|
||||
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3] == GGML_BACKEND_SPLIT_STATE_MIRRORED);
|
||||
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4] == GGML_BACKEND_SPLIT_STATE_BY_NE0);
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
|
||||
GGML_ASSERT( src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
||||
GGML_ASSERT( src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
||||
GGML_ASSERT( src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_2);
|
||||
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED);
|
||||
GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4].axis == GGML_BACKEND_SPLIT_AXIS_0);
|
||||
return {GGML_BACKEND_SPLIT_AXIS_1, {0}};
|
||||
};
|
||||
|
||||
auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
|
||||
if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) {
|
||||
ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer));
|
||||
const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context;
|
||||
return dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
|
||||
ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud);
|
||||
if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) {
|
||||
const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1;
|
||||
int64_t ne_sum = 0;
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
GGML_ASSERT(ret.ne[j] % granularity == 0);
|
||||
ne_sum += ret.ne[j];
|
||||
}
|
||||
GGML_ASSERT(ne_sum == tensor->ne[ret.axis]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::vector<ggml_backend_meta_split_state> src_split_states(GGML_MAX_SRC, GGML_BACKEND_SPLIT_STATE_NONE);
|
||||
std::vector<ggml_backend_meta_split_state> src_split_states(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}});
|
||||
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i] == nullptr || tensor->src[i] == tensor) {
|
||||
src_split_states[i] = GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
src_split_states[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
continue;
|
||||
}
|
||||
src_split_states[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true);
|
||||
}
|
||||
|
||||
ggml_backend_meta_split_state split_state;
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_NONE: {
|
||||
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
}
|
||||
split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}};
|
||||
} break;
|
||||
case GGML_OP_DUP: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID: {
|
||||
return handle_bin_bcast(src_split_states);
|
||||
}
|
||||
split_state = handle_bin_bcast(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV: {
|
||||
return handle_bin_bcast(src_split_states);
|
||||
}
|
||||
split_state = handle_bin_bcast(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_SUM: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_CUMSUM:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_COUNT_EQUAL: {
|
||||
return handle_per_row(src_split_states);
|
||||
}
|
||||
split_state = handle_per_row(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_CONCAT: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SILU_BACK: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM: {
|
||||
return handle_per_row(src_split_states);
|
||||
}
|
||||
split_state = handle_per_row(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID: {
|
||||
return handle_mul_mat(src_split_states);
|
||||
}
|
||||
split_state = handle_mul_mat(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SCALE: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_RESHAPE: {
|
||||
return handle_reshape(src_split_states);
|
||||
}
|
||||
split_state = handle_reshape(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_VIEW: {
|
||||
return handle_view(src_split_states);
|
||||
}
|
||||
split_state = handle_view(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_PERMUTE: {
|
||||
return handle_permute(src_split_states);
|
||||
}
|
||||
split_state = handle_permute(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_GET_ROWS_BACK: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SET_ROWS: {
|
||||
return handle_set_rows(src_split_states);
|
||||
}
|
||||
split_state = handle_set_rows(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_DIAG_MASK_ZERO: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_SOFT_MAX_BACK: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_ROPE: {
|
||||
return handle_rope(src_split_states);
|
||||
}
|
||||
split_state = handle_rope(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_CLAMP: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
|
|
@@ -1316,22 +1379,22 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
case GGML_OP_ROLL:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_TOP_K: {
|
||||
return handle_per_row(src_split_states);
|
||||
}
|
||||
split_state = handle_per_row(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_FILL: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
return handle_flash_attn_ext(src_split_states);
|
||||
}
|
||||
split_state = handle_flash_attn_ext(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
|
|
@@ -1343,45 +1406,97 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SOLVE_TRI: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_UNARY: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
case GGML_OP_MAP_CUSTOM1:
|
||||
case GGML_OP_MAP_CUSTOM2:
|
||||
case GGML_OP_MAP_CUSTOM3:
|
||||
case GGML_OP_CUSTOM: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ true);
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK: {
|
||||
return handle_per_row(src_split_states);
|
||||
}
|
||||
split_state = handle_per_row(src_split_states);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
case GGML_OP_GLU: {
|
||||
return handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
}
|
||||
split_state = handle_generic(src_split_states, /*scalar_only =*/ false);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op));
|
||||
return GGML_BACKEND_SPLIT_STATE_UNKNOWN;
|
||||
}
|
||||
split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}};
|
||||
} break;
|
||||
}
|
||||
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
|
||||
bool src_split_by_axis_found = false;
|
||||
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);
|
||||
|
||||
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i] == nullptr || src_split_states[i].axis < 0 || src_split_states[i].axis >= GGML_MAX_DIMS) {
|
||||
continue;
|
||||
}
|
||||
if (src_split_by_axis_found) {
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
// Assert that ratio is consistent:
|
||||
GGML_ASSERT( split_state.ne[j] * tensor->src[i]->ne[src_split_states[i].axis]
|
||||
== src_split_states[i].ne[j] * tensor->ne[split_state.axis]);
|
||||
}
|
||||
} else {
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
// Take over ratio from src:
|
||||
split_state.ne[j] = src_split_states[i].ne[j] * tensor->ne[split_state.axis];
|
||||
GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_split_states[i].axis] == 0);
|
||||
split_state.ne[j] /= tensor->src[i]->ne[src_split_states[i].axis];
|
||||
}
|
||||
}
|
||||
src_split_by_axis_found = true;
|
||||
}
|
||||
GGML_ASSERT(src_split_by_axis_found);
|
||||
}
|
||||
return split_state;
|
||||
};
|
||||
|
||||
const std::pair key = std::make_pair(tensor, assume_sync);
|
||||
|
||||
if (buf_ctx->split_state_cache.find(key) == buf_ctx->split_state_cache.end()) {
|
||||
buf_ctx->split_state_cache[key] = calculate_split_state();
|
||||
if (buf_ctx->debug > 0) {
|
||||
std::string srcs_info;
|
||||
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i] == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!srcs_info.empty()) {
|
||||
srcs_info += ", ";
|
||||
}
|
||||
const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[i], true);
|
||||
const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis);
|
||||
std::string ne_info;
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
if (!ne_info.empty()) {
|
||||
ne_info += ", ";
|
||||
}
|
||||
ne_info += std::to_string(split_state.ne[j]);
|
||||
}
|
||||
srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]";
|
||||
}
|
||||
std::string ne_info;
|
||||
for (size_t j = 0; j < n_bufs; j++) {
|
||||
if (!ne_info.empty()) {
|
||||
ne_info += ", ";
|
||||
}
|
||||
ne_info += std::to_string(buf_ctx->split_state_cache[key].ne[j]);
|
||||
}
|
||||
GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op),
|
||||
ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].axis), ne_info.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key];
|
||||
GGML_ASSERT(ret != GGML_BACKEND_SPLIT_STATE_NONE);
|
||||
if (assume_sync && ret == GGML_BACKEND_SPLIT_STATE_UNKNOWN) {
|
||||
GGML_ABORT("fatal error");
|
||||
ret = GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
}
|
||||
GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE && ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
|||
|
|
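When a result tensor is itself split along an axis, the bookkeeping above takes over the per-device extents from one of its split sources, scaled by the ratio of the two full extents, and then asserts that every other split source implies exactly the same partition. A self-contained sketch of that scaling step (helper name and numbers are illustrative).

#include <cassert>
#include <cstdint>
#include <vector>

// Scale a source partition src_ne (summing to src_full) to a destination axis of length dst_full.
static std::vector<int64_t> take_over_ratio(const std::vector<int64_t> & src_ne,
                                            int64_t src_full, int64_t dst_full) {
    std::vector<int64_t> dst_ne(src_ne.size());
    for (size_t j = 0; j < src_ne.size(); j++) {
        const int64_t scaled = src_ne[j] * dst_full;
        assert(scaled % src_full == 0); // mirrors the divisibility assert in the code above
        dst_ne[j] = scaled / src_full;
    }
    return dst_ne;
}

int main() {
    // source axis: 4100 rows split as {1367, 1367, 1366}; destination axis is twice as long
    const std::vector<int64_t> dst = take_over_ratio({1367, 1367, 1366}, 4100, 8200);
    assert(dst[0] == 2734 && dst[1] == 2734 && dst[2] == 2732);
    return 0;
}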
@@ -26,6 +26,103 @@
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) {
|
||||
const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata;
|
||||
|
||||
auto get_split_axis = [&]() -> ggml_backend_meta_split_axis {
|
||||
// attention
|
||||
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
|
||||
if (std::regex_match(tensor->name, pattern_qkv_weight)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||
}
|
||||
const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
|
||||
if (std::regex_match(tensor->name, pattern_qkv_bias)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
|
||||
if (std::regex_match(tensor->name, pattern_qk_norm)) {
|
||||
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
|
||||
}
|
||||
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
|
||||
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
||||
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
|
||||
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
|
||||
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||
}
|
||||
|
||||
// FFN
|
||||
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||
}
|
||||
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||
}
|
||||
|
||||
// output
|
||||
const std::regex pattern_output_weight("output\\.weight");
|
||||
if (std::regex_match(tensor->name, pattern_output_weight)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_1;
|
||||
}
|
||||
const std::regex pattern_output_bias("output\\.bias");
|
||||
if (std::regex_match(tensor->name, pattern_output_bias)) {
|
||||
return GGML_BACKEND_SPLIT_AXIS_0;
|
||||
}
|
||||
|
||||
// everything else
|
||||
return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
|
||||
};
|
||||
|
||||
ggml_backend_meta_split_state split_state;
|
||||
split_state.axis = get_split_axis();
|
||||
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
|
||||
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
||||
const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 1 : 32; // TODO determine more generally
|
||||
const int64_t ne_full = tensor->ne[get_split_axis()];
|
||||
GGML_ASSERT(ne_full % granularity == 0);
|
||||
std::vector<float> tensor_split_scan;
|
||||
tensor_split_scan.reserve(ud->n_devices);
|
||||
for (size_t j = 0; j < ud->n_devices; j++) {
|
||||
tensor_split_scan.push_back(ud->tensor_split[j]);
|
||||
if (j > 0) {
|
||||
tensor_split_scan[j] += tensor_split_scan[j - 1];
|
||||
}
|
||||
}
|
||||
int64_t low = 0;
|
||||
size_t j = 0;
|
||||
for (; j < ud->n_devices - 1; j++) {
|
||||
int64_t high = tensor_split_scan.back() == 0.0f ?
|
||||
ne_full * (j+1)/ud->n_devices : ne_full * tensor_split_scan[j]/tensor_split_scan.back();
|
||||
if (high % granularity != 0) {
|
||||
high -= high % granularity;
|
||||
}
|
||||
split_state.ne[j] = high - low;
|
||||
low = high;
|
||||
}
|
||||
split_state.ne[j] = ne_full - low;
|
||||
} else {
|
||||
memset(split_state.ne, 0, sizeof(split_state.ne));
|
||||
}
|
||||
return split_state;
|
||||
GGML_UNUSED(userdata);
|
||||
}
|
||||
|
||||
const char * llm_type_name(llm_type type) {
|
||||
switch (type) {
|
||||
case LLM_TYPE_14M: return "14M";
|
||||
|
|
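The core of the change is in llama_meta_device_get_split_state above: the user-supplied tensor_split fractions are turned into per-device extents that stay aligned to a block granularity even when the dimension is not divisible by the number of devices. A prefix sum over the fractions picks the cut points, each cut is rounded down to a multiple of the granularity, and the last device absorbs whatever remains. Below is a self-contained sketch of just that arithmetic; the granularity of 32 is the value used in the diff, everything else is illustrative.

#include <cstdint>
#include <cstdio>
#include <vector>

// Partition ne_full elements across devices proportionally to `fractions`,
// rounding every cut point down to a multiple of `granularity`.
static std::vector<int64_t> partition(int64_t ne_full, int64_t granularity,
                                      const std::vector<float> & fractions) {
    const size_t n = fractions.size();
    std::vector<float> scan(fractions);
    for (size_t j = 1; j < n; j++) {
        scan[j] += scan[j - 1]; // prefix sum of the fractions
    }
    std::vector<int64_t> ne(n);
    int64_t low = 0;
    for (size_t j = 0; j + 1 < n; j++) {
        int64_t high = scan.back() == 0.0f ?
            ne_full*int64_t(j + 1)/int64_t(n) :   // no fractions given: split evenly
            int64_t(ne_full*scan[j]/scan.back()); // otherwise: proportional cut point
        high -= high % granularity;               // keep the cut block-aligned
        ne[j] = high - low;
        low   = high;
    }
    ne[n - 1] = ne_full - low;                    // the last device takes the remainder
    return ne;
}

int main() {
    // 4160 rows, block size 32, three equally weighted devices -> 1376 + 1376 + 1408
    for (int64_t v : partition(4160, 32, {1.0f, 1.0f, 1.0f})) {
        std::printf("%lld ", (long long) v);
    }
    std::printf("\n");
}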
@@ -7610,6 +7707,10 @@ size_t llama_model::n_devices() const {
return devices.size();
|
||||
}
|
||||
|
||||
const float * llama_model::tensor_split() const {
|
||||
return params.tensor_split;
|
||||
}
|
||||
|
||||
uint32_t llama_model::n_gpu_layers() const {
|
||||
return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1;
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -438,6 +438,13 @@ struct llama_layer {
struct llama_layer_nextn nextn;
|
||||
};
|
||||
|
||||
struct llama_meta_device_get_split_state_userdata {
|
||||
size_t n_devices;
|
||||
const float * tensor_split;
|
||||
};
|
||||
|
||||
struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
|
||||
|
||||
struct llama_model {
|
||||
llm_type type = LLM_TYPE_UNKNOWN;
|
||||
llm_arch arch = LLM_ARCH_UNKNOWN;
|
||||
|
|
@@ -498,6 +505,9 @@ struct llama_model {
// for keeping track of associated LoRA adapters
|
||||
std::unordered_set<llama_adapter_lora *> loras;
|
||||
|
||||
// userdata for assigning split states to statically allocated tensors on the meta device
|
||||
struct llama_meta_device_get_split_state_userdata get_split_state_ud;
|
||||
|
||||
int64_t t_load_us = 0;
|
||||
int64_t t_start_us = 0;
|
||||
|
||||
|
|
@@ -518,6 +528,7 @@ struct llama_model {
size_t size() const; // file size
|
||||
size_t n_tensors() const;
|
||||
size_t n_devices() const;
|
||||
const float * tensor_split() const;
|
||||
|
||||
uint32_t n_gpu_layers() const;
|
||||
llama_split_mode split_mode() const;
|
||||
|
|
|
|||
|
|
@@ -884,67 +884,6 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
return 0;
|
||||
}
|
||||
|
||||
static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(const struct ggml_tensor * tensor, void * userdata) {
|
||||
// attention
|
||||
const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight");
|
||||
if (std::regex_match(tensor->name, pattern_qkv_weight)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
|
||||
}
|
||||
const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias");
|
||||
if (std::regex_match(tensor->name, pattern_qkv_bias)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight");
|
||||
if (std::regex_match(tensor->name, pattern_qk_norm)) {
|
||||
return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_BY_NE1;
|
||||
}
|
||||
const std::regex pattern_kv_cache("cache_(k|v)_l\\d*");
|
||||
const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight");
|
||||
if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight");
|
||||
if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias");
|
||||
if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
}
|
||||
|
||||
// FFN
|
||||
const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
|
||||
}
|
||||
const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias");
|
||||
if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
}
|
||||
|
||||
// output
|
||||
const std::regex pattern_output_weight("output\\.weight");
|
||||
if (std::regex_match(tensor->name, pattern_output_weight)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE1;
|
||||
}
|
||||
const std::regex pattern_output_bias("output\\.bias");
|
||||
if (std::regex_match(tensor->name, pattern_output_bias)) {
|
||||
return GGML_BACKEND_SPLIT_STATE_BY_NE0;
|
||||
}
|
||||
|
||||
// everything else
|
||||
return GGML_BACKEND_SPLIT_STATE_MIRRORED;
|
||||
GGML_UNUSED(userdata);
|
||||
}
|
||||
|
||||
static struct llama_model * llama_model_load_from_file_impl(
|
||||
const std::string & path_model,
|
||||
std::vector<std::string> & splits,
|
||||
|
|
@@ -982,7 +921,10 @@ static struct llama_model * llama_model_load_from_file_impl(
while (params.devices[n_devs]) {
|
||||
n_devs++;
|
||||
}
|
||||
model->devices.push_back(ggml_backend_meta_device(params.devices, n_devs, llama_meta_device_get_tensor_split, nullptr));
|
||||
model->get_split_state_ud.n_devices = n_devs;
|
||||
model->get_split_state_ud.tensor_split = model->tensor_split();
|
||||
model->devices.push_back(ggml_backend_meta_device(
|
||||
params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
||||
} else {
|
||||
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
|
||||
model->devices.push_back(*dev);
|
||||
|
|
@@ -1004,7 +946,10 @@ static struct llama_model * llama_model_load_from_file_impl(
}
|
||||
GGML_ASSERT(devs.size() >= 2);
|
||||
GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
|
||||
gpus.push_back(ggml_backend_meta_device(devs.data(), devs.size() - 1, llama_meta_device_get_tensor_split, nullptr));
|
||||
model->get_split_state_ud.n_devices = devs.size() - 1;
|
||||
model->get_split_state_ud.tensor_split = model->tensor_split();
|
||||
gpus.push_back(ggml_backend_meta_device(
|
||||
devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud));
|
||||
} else {
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
|