From 31e4f189bbdb901a97ebf796a98049c5568379f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 11 Feb 2026 23:34:43 +0100 Subject: [PATCH] support for tensor dims % n_devs != 0 --- ggml/include/ggml-backend.h | 33 ++- ggml/src/ggml-backend-meta.cpp | 481 ++++++++++++++++++++------------- src/llama-model.cpp | 101 +++++++ src/llama-model.h | 11 + src/llama.cpp | 71 +---- 5 files changed, 438 insertions(+), 259 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 8de6f950b2..da7e1c1c0d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -223,24 +223,31 @@ extern "C" { // Meta backend // - enum ggml_backend_meta_split_state { - // tensor split by tensor dimensions: - GGML_BACKEND_SPLIT_STATE_BY_NE0 = 0, - GGML_BACKEND_SPLIT_STATE_BY_NE1 = 1, - GGML_BACKEND_SPLIT_STATE_BY_NE2 = 2, - GGML_BACKEND_SPLIT_STATE_BY_NE3 = 3, + #define GGML_BACKEND_META_MAX_DEVICES 16 - GGML_BACKEND_SPLIT_STATE_MIRRORED = 10, // all values on all backends - GGML_BACKEND_SPLIT_STATE_PARTIAL = 11, // each backend has a partial sum + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum // for internal bookkeeping only: - GGML_BACKEND_SPLIT_STATE_NONE = 98, - GGML_BACKEND_SPLIT_STATE_UNKNOWN = 99, + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + int64_t ne[GGML_BACKEND_META_MAX_DEVICES]; }; // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: - typedef enum ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); - + typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); GGML_API bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); GGML_API size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev); @@ -263,7 +270,7 @@ extern "C" { GGML_API size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend); GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); - GGML_API enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); + GGML_API struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); // temporary workaround to statically allocate tensors from a context in a deduplicated way: GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c48ef18e71..364d064c21 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -20,6 +20,29 @@ struct ggml_backend_meta_buffer_type; struct ggml_backend_meta_buffer; struct ggml_backend_meta; +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis 
split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + // // meta backend device // @@ -351,6 +374,13 @@ struct ggml_backend_meta_buffer_context { buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} }; std::vector buf_configs; + + int debug; + + ggml_backend_meta_buffer_context() { + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } }; static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { @@ -374,32 +404,32 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); - GGML_ASSERT(split_state != GGML_BACKEND_SPLIT_STATE_UNKNOWN); + GGML_ASSERT(split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); - int split_dim = split_state; + int split_dim = split_state.axis; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; for (size_t k = 0; k < GGML_MAX_DIMS; k++) { ne[k] = tensor->ne[k]; nb[k] = tensor->nb[k]; } - if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { - GGML_ASSERT(ne[split_dim] % (split_dim == 0 ? n_simple_bufs*ggml_blck_size(tensor->type) : n_simple_bufs) == 0); - ne[split_dim] /= n_simple_bufs; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (tensor->nb[i] > tensor->nb[split_dim]) { - GGML_ASSERT(nb[i] % (n_simple_bufs*ggml_element_size(tensor)) == 0); - nb[i] /= n_simple_bufs; - } - } - } std::vector simple_tensors; - simple_tensors.reserve(buf_ctx->buf_configs.size()); - for (size_t j = 0; j < buf_ctx->buf_configs.size(); j++) { + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = split_state.ne[j]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); t_ij->op = tensor->op; for (int i = 0; i < GGML_MAX_DIMS; i++) { @@ -444,12 +474,12 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - switch (split_state) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: { + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". 
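The splicing in the axis-0/1/2 cases below relies on the host tensor being contiguous: every slice of size tensor->nb[axis + 1] (a "full chunk") is the concatenation of one sub-chunk per device, whose sizes are the per-device nb[axis + 1] values and sum to the full chunk size (the loop asserts offset_j == chunk_size_full). A minimal standalone sketch of that layout, using plain memcpy in place of the strided ggml_backend_tensor_set_2d copies; all names and buffers here are hypothetical, not part of the patch:

    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Splice a host buffer of n_chunks "full" chunks into per-device buffers.
    // chunk_size_dev[j] is device j's share of one full chunk; the shares sum to
    // chunk_size_full, and dev_bufs[j] is pre-sized to n_chunks*chunk_size_dev[j].
    static void splice_chunks(const char * host, size_t n_chunks, size_t chunk_size_full,
                              const std::vector<size_t> & chunk_size_dev,
                              std::vector<std::vector<char>> & dev_bufs) {
        for (size_t i = 0; i < n_chunks; i++) {
            size_t offset_j = 0; // running offset of device j's share inside the full chunk
            for (size_t j = 0; j < chunk_size_dev.size(); j++) {
                const char * src = host + i*chunk_size_full + offset_j;
                char       * dst = dev_bufs[j].data() + i*chunk_size_dev[j];
                std::memcpy(dst, src, chunk_size_dev[j]);
                offset_j += chunk_size_dev[j];
            }
        }
    }

In the patch itself each device gets its whole share in a single 2D copy with row stride chunk_size_full on the host side, rather than one memcpy per chunk as in this sketch.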
- const size_t chunk_size_full = tensor->nb[int(split_state) + 1]; + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; GGML_ASSERT(offset % chunk_size_full == 0); GGML_ASSERT(size % chunk_size_full == 0); const int64_t i_start = offset /chunk_size_full; @@ -457,13 +487,13 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg size_t offset_j = 0; for (size_t j = 0; j < n_bufs; j++){ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1]; + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; } GGML_ASSERT(offset_j == chunk_size_full); } break; - case GGML_BACKEND_SPLIT_STATE_MIRRORED: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { for (size_t j = 0; j < n_bufs; j++){ ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); ggml_backend_tensor_set(simple_tensor, data, offset, size); @@ -482,12 +512,12 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - switch (split_state) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: { + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". - const size_t chunk_size_full = tensor->nb[int(split_state) + 1]; + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; GGML_ASSERT(offset % chunk_size_full == 0); GGML_ASSERT(size % chunk_size_full == 0); const int64_t i_start = offset /chunk_size_full; @@ -495,13 +525,13 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co size_t offset_j = 0; for (size_t j = 0; j < n_bufs; j++){ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1]; + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; } GGML_ASSERT(offset_j == chunk_size_full); } break; - case GGML_BACKEND_SPLIT_STATE_MIRRORED: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { // TODO other simple backend may be better const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); ggml_backend_tensor_get(simple_tensor, data, offset, size); @@ -578,7 +608,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac /*.no_alloc =*/ true, }; - ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context; + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); size_t max_size = 0; buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { @@ -599,7 +629,7 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc /*.no_alloc =*/ true, }; - ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context; + ggml_backend_meta_buffer_context * 
meta_buf_ctx = new ggml_backend_meta_buffer_context(); meta_buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); @@ -723,12 +753,12 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - switch (split_state) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: { + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". - const size_t chunk_size_full = tensor->nb[int(split_state) + 1]; + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; GGML_ASSERT(offset % chunk_size_full == 0); GGML_ASSERT(size % chunk_size_full == 0); const int64_t i_start = offset /chunk_size_full; @@ -737,14 +767,14 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens for (size_t j = 0; j < n_backends; j++){ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1]; + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; } GGML_ASSERT(offset_j == chunk_size_full); } break; - case GGML_BACKEND_SPLIT_STATE_MIRRORED: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { for (size_t j = 0; j < n_backends; j++) { ggml_backend_tensor_set_async( ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); @@ -763,12 +793,12 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - switch (split_state) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: { + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". 
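For reference, the shard geometry set up in ggml_backend_meta_buffer_init_tensor earlier in this file: device j keeps split_state.ne[j] elements of the split dimension, and the strides of the dimensions allocated outside it shrink by the same ratio, which stays exact because each share respects the block-size granularity. A minimal restatement of that arithmetic with hypothetical function and parameter names, not code from the patch:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    constexpr int MAX_DIMS = 4;

    // Derive the shard shape/strides for one device from the full tensor shape (ne, nb)
    // and its share ne_j of the split dimension.
    static void shard_shape(const int64_t ne[MAX_DIMS], const size_t nb[MAX_DIMS],
                            int split_dim, int64_t ne_j,
                            int64_t ne_out[MAX_DIMS], size_t nb_out[MAX_DIMS]) {
        for (int d = 0; d < MAX_DIMS; d++) {
            ne_out[d] = ne[d];
            nb_out[d] = nb[d];
        }
        ne_out[split_dim] = ne_j;
        for (int d = 0; d < MAX_DIMS; d++) {
            // strides larger than the split dimension's stride shrink proportionally
            if (nb[d] > nb[split_dim]) {
                assert((nb[d] * (size_t) ne_j) % (size_t) ne[split_dim] == 0); // sanity check
                nb_out[d] = nb[d] * ne_j / ne[split_dim];
            }
        }
    }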
- const size_t chunk_size_full = tensor->nb[int(split_state) + 1]; + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; GGML_ASSERT(offset % chunk_size_full == 0); GGML_ASSERT(size % chunk_size_full == 0); const int64_t i_start = offset /chunk_size_full; @@ -777,14 +807,14 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm for (size_t j = 0; j < n_backends; j++){ ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1]; + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; } GGML_ASSERT(offset_j == chunk_size_full); } break; - case GGML_BACKEND_SPLIT_STATE_MIRRORED: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { // TODO other simple backend may be better ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); @@ -826,11 +856,11 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, int i_start = 0; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - const bool partial = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false) == GGML_BACKEND_SPLIT_STATE_PARTIAL; - if (partial) { + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); } - const bool new_subgraph = i + 1 == cgraph->n_nodes || partial; + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; if (!new_subgraph) { continue; } @@ -1039,266 +1069,299 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz return backend_ctx->backend_configs[index].backend; } -enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { - GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); +struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + if (a.ne[j] != b.ne[j]) { + return false; + } + } + return true; + }; + auto handle_generic = [&](const std::vector & src_split_states, bool scalar_only) -> ggml_backend_meta_split_state { - ggml_backend_meta_split_state homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_NONE; + ggml_backend_meta_split_state homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { continue; } - if (homogeneous_src_split_state == GGML_BACKEND_SPLIT_STATE_NONE) { + if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { 
homogeneous_src_split_state = src_split_states[i]; - } else if (src_split_states[i] != homogeneous_src_split_state) { - homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN; + } else if (!split_states_equal(src_split_states[i], homogeneous_src_split_state)) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + break; } } - if (homogeneous_src_split_state == GGML_BACKEND_SPLIT_STATE_NONE) { - homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN; + if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; } - if (scalar_only && homogeneous_src_split_state >= 0 && homogeneous_src_split_state < GGML_MAX_DIMS) { - homogeneous_src_split_state = GGML_BACKEND_SPLIT_STATE_UNKNOWN; + if (scalar_only && homogeneous_src_split_state.axis >= 0 && homogeneous_src_split_state.axis < GGML_MAX_DIMS) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; } - GGML_ASSERT(homogeneous_src_split_state != GGML_BACKEND_SPLIT_STATE_UNKNOWN); + GGML_ASSERT(homogeneous_src_split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return homogeneous_src_split_state; }; // Some ops process data on a per-row bases: auto handle_per_row = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - GGML_ASSERT(src_split_states[0] != GGML_BACKEND_SPLIT_STATE_BY_NE0); + GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_0); return src_split_states[0]; }; // Some ops broadcast the src1 data across src0: auto handle_bin_bcast = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - if (src_split_states[0] >= 0 && src_split_states[0] < GGML_MAX_DIMS && - tensor->src[1]->ne[int(src_split_states[0])] == 1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) { + if (src_split_states[0].axis >= 0 && src_split_states[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_split_states[0].axis] == 1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_split_states[0]; } - if (src_split_states[0] == src_split_states[1] && src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED) { + if (src_split_states[0].axis == src_split_states[1].axis && src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_split_states[0]; // GGML_ADD_ID } - GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2] == GGML_BACKEND_SPLIT_STATE_MIRRORED); + GGML_ASSERT(tensor->src[2] == nullptr || src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); return handle_generic(src_split_states, /*scalar_only =*/ false); }; auto handle_mul_mat = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) { - return GGML_BACKEND_SPLIT_STATE_MIRRORED; + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}}; } - if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE1 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_split_states[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + return ret; } - if (src_split_states[0] == 
GGML_BACKEND_SPLIT_STATE_BY_NE0 && src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE0) { - return assume_sync ? GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_PARTIAL; + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(src_split_states[0].ne[j] == src_split_states[1].ne[j]); + } + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}}; } GGML_ABORT("fatal error"); - return GGML_BACKEND_SPLIT_STATE_UNKNOWN; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; }; auto handle_reshape = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - switch (src_split_states[0]) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: - case GGML_BACKEND_SPLIT_STATE_BY_NE3: { + switch (src_split_states[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { GGML_ASSERT(ggml_is_contiguous(tensor)); int64_t base_ne_in = 1; - for (int dim = 0; dim <= int(src_split_states[0]); dim++) { + for (int dim = 0; dim <= src_split_states[0].axis; dim++) { base_ne_in *= tensor->src[0]->ne[dim]; } int64_t base_ne_out = 1; for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; if (base_ne_out_next == base_ne_in) { - return ggml_backend_meta_split_state(dim); + return {ggml_backend_meta_split_axis(dim), {0}}; + } + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); + return {ggml_backend_meta_split_axis(dim + 1), {0}}; } base_ne_out = base_ne_out_next; } GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); } - case GGML_BACKEND_SPLIT_STATE_MIRRORED: - case GGML_BACKEND_SPLIT_STATE_PARTIAL: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { return src_split_states[0]; } default: { GGML_ABORT("fatal error"); - return GGML_BACKEND_SPLIT_STATE_UNKNOWN; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; } } }; auto handle_view = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - if (ggml_is_contiguous(tensor)) { + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->view_src)) { return handle_reshape(src_split_states); } - if (src_split_states[0] == GGML_BACKEND_SPLIT_STATE_MIRRORED || src_split_states[0] == GGML_BACKEND_SPLIT_STATE_PARTIAL) { + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { return src_split_states[0]; } - GGML_ABORT("non-contioguos view not implemented"); - return GGML_BACKEND_SPLIT_STATE_UNKNOWN; + GGML_ABORT("view of permuted tensor not implemented"); + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; }; auto handle_permute = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - switch (src_split_states[0]) { - case GGML_BACKEND_SPLIT_STATE_BY_NE0: - case GGML_BACKEND_SPLIT_STATE_BY_NE1: - case GGML_BACKEND_SPLIT_STATE_BY_NE2: - case GGML_BACKEND_SPLIT_STATE_BY_NE3: { - return ggml_backend_meta_split_state(tensor->op_params[int(src_split_states[0])]); + switch (src_split_states[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + return 
{ggml_backend_meta_split_axis(tensor->op_params[src_split_states[0].axis]), {0}}; } - case GGML_BACKEND_SPLIT_STATE_MIRRORED: - case GGML_BACKEND_SPLIT_STATE_PARTIAL: { + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { return src_split_states[0]; } default: { GGML_ABORT("fatal error"); - return GGML_BACKEND_SPLIT_STATE_UNKNOWN; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; } } }; auto handle_set_rows = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE0); - GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED); - GGML_ASSERT(src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE0); + GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_split_states[0], src_split_states[2])); return src_split_states[0]; }; auto handle_rope = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - GGML_ASSERT(src_split_states[1] == GGML_BACKEND_SPLIT_STATE_MIRRORED); + GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); return src_split_states[0]; }; auto handle_flash_attn_ext = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { - GGML_ASSERT( src_split_states[0] == GGML_BACKEND_SPLIT_STATE_BY_NE2); - GGML_ASSERT( src_split_states[1] == GGML_BACKEND_SPLIT_STATE_BY_NE2); - GGML_ASSERT( src_split_states[2] == GGML_BACKEND_SPLIT_STATE_BY_NE2); - GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3] == GGML_BACKEND_SPLIT_STATE_MIRRORED); - GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4] == GGML_BACKEND_SPLIT_STATE_BY_NE0); - return GGML_BACKEND_SPLIT_STATE_BY_NE1; + GGML_ASSERT( src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}}; }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; - return dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? 
ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[j] % granularity == 0); + ne_sum += ret.ne[j]; + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; } - std::vector src_split_states(GGML_MAX_SRC, GGML_BACKEND_SPLIT_STATE_NONE); + std::vector src_split_states(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}}); for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { - src_split_states[i] = GGML_BACKEND_SPLIT_STATE_UNKNOWN; + src_split_states[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; continue; } src_split_states[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); } + ggml_backend_meta_split_state split_state; switch (tensor->op) { case GGML_OP_NONE: { - return GGML_BACKEND_SPLIT_STATE_MIRRORED; - } + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}}; + } break; case GGML_OP_DUP: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_ADD: case GGML_OP_ADD_ID: { - return handle_bin_bcast(src_split_states); - } + split_state = handle_bin_bcast(src_split_states); + } break; case GGML_OP_ADD1: case GGML_OP_ACC: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: { - return handle_bin_bcast(src_split_states); - } + split_state = handle_bin_bcast(src_split_states); + } break; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_LOG: case GGML_OP_SIN: case GGML_OP_COS: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_SUM: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SUM_ROWS: case GGML_OP_CUMSUM: case GGML_OP_MEAN: case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: { - return handle_per_row(src_split_states); - } + split_state = handle_per_row(src_split_states); + } break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_CONCAT: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SILU_BACK: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_GROUP_NORM: case GGML_OP_L2_NORM: { - return handle_per_row(src_split_states); - } + split_state = handle_per_row(src_split_states); + } break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { - return handle_mul_mat(src_split_states); - } + split_state = handle_mul_mat(src_split_states); + } break; case GGML_OP_OUT_PROD: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SCALE: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_CONT: { - return handle_generic(src_split_states, /*scalar_only =*/ true); 
- } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_RESHAPE: { - return handle_reshape(src_split_states); - } + split_state = handle_reshape(src_split_states); + } break; case GGML_OP_VIEW: { - return handle_view(src_split_states); - } + split_state = handle_view(src_split_states); + } break; case GGML_OP_PERMUTE: { - return handle_permute(src_split_states); - } + split_state = handle_permute(src_split_states); + } break; case GGML_OP_TRANSPOSE: case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SET_ROWS: { - return handle_set_rows(src_split_states); - } + split_state = handle_set_rows(src_split_states); + } break; case GGML_OP_DIAG: case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_ZERO: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_ROPE: { - return handle_rope(src_split_states); - } + split_state = handle_rope(src_split_states); + } break; case GGML_OP_ROPE_BACK: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_CLAMP: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: @@ -1316,22 +1379,22 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc case GGML_OP_ROLL: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_ARGSORT: case GGML_OP_TOP_K: { - return handle_per_row(src_split_states); - } + split_state = handle_per_row(src_split_states); + } break; case GGML_OP_LEAKY_RELU: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_TRI: case GGML_OP_FILL: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_FLASH_ATTN_EXT: { - return handle_flash_attn_ext(src_split_states); - } + split_state = handle_flash_attn_ext(src_split_states); + } break; case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: @@ -1343,45 +1406,97 @@ enum ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struc case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: case GGML_OP_SOLVE_TRI: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_UNARY: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; case GGML_OP_MAP_CUSTOM1: case GGML_OP_MAP_CUSTOM2: case GGML_OP_MAP_CUSTOM3: 
case GGML_OP_CUSTOM: { - return handle_generic(src_split_states, /*scalar_only =*/ true); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { - return handle_per_row(src_split_states); - } + split_state = handle_per_row(src_split_states); + } break; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: case GGML_OP_GLU: { - return handle_generic(src_split_states, /*scalar_only =*/ false); - } + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; default: { GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); - return GGML_BACKEND_SPLIT_STATE_UNKNOWN; - } + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } break; } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool src_split_by_axis_found = false; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_split_states[i].axis < 0 || src_split_states[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (src_split_by_axis_found) { + for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: + GGML_ASSERT( split_state.ne[j] * tensor->src[i]->ne[src_split_states[i].axis] + == src_split_states[i].ne[j] * tensor->ne[split_state.axis]); + } + } else { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + split_state.ne[j] = src_split_states[i].ne[j] * tensor->ne[split_state.axis]; + GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_split_states[i].axis] == 0); + split_state.ne[j] /= tensor->src[i]->ne[src_split_states[i].axis]; + } + } + src_split_by_axis_found = true; + } + GGML_ASSERT(src_split_by_axis_found); + } + return split_state; }; const std::pair key = std::make_pair(tensor, assume_sync); if (buf_ctx->split_state_cache.find(key) == buf_ctx->split_state_cache.end()) { buf_ctx->split_state_cache[key] = calculate_split_state(); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(buf_ctx->split_state_cache[key].ne[j]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].axis), ne_info.c_str()); + } } ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key]; - GGML_ASSERT(ret != GGML_BACKEND_SPLIT_STATE_NONE); - if (assume_sync && ret == GGML_BACKEND_SPLIT_STATE_UNKNOWN) { - GGML_ABORT("fatal error"); - ret = GGML_BACKEND_SPLIT_STATE_MIRRORED; - } + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE && ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return ret; } diff --git a/src/llama-model.cpp 
b/src/llama-model.cpp index 9376ea5631..4aed50c903 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -26,6 +26,103 @@ #include #include +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + + auto get_split_axis = [&]() -> ggml_backend_meta_split_axis { + // attention + const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight"); + if (std::regex_match(tensor->name, pattern_qkv_weight)) { + return GGML_BACKEND_SPLIT_AXIS_1; + } + const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias"); + if (std::regex_match(tensor->name, pattern_qkv_bias)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + if (std::regex_match(tensor->name, pattern_qk_norm)) { + return tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1; + } + const std::regex pattern_kv_cache("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight"); + if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight"); + if (std::regex_match(tensor->name, pattern_attn_out_weight)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias"); + if (std::regex_match(tensor->name, pattern_attn_out_bias)) { + return GGML_BACKEND_SPLIT_AXIS_MIRRORED; + } + + // FFN + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) { + return GGML_BACKEND_SPLIT_AXIS_1; + } + const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + if (std::regex_match(tensor->name, pattern_ffn_down_weight)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias"); + if (std::regex_match(tensor->name, pattern_ffn_down_bias)) { + return GGML_BACKEND_SPLIT_AXIS_MIRRORED; + } + + // output + const std::regex pattern_output_weight("output\\.weight"); + if (std::regex_match(tensor->name, pattern_output_weight)) { + return GGML_BACKEND_SPLIT_AXIS_1; + } + const std::regex pattern_output_bias("output\\.bias"); + if (std::regex_match(tensor->name, pattern_output_bias)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + + // everything else + return GGML_BACKEND_SPLIT_AXIS_MIRRORED; + }; + + ggml_backend_meta_split_state split_state; + split_state.axis = get_split_axis(); + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight"); + const int64_t granularity = std::regex_match(tensor->name, pattern_attn_sinks) ? 
1 : 32; // TODO determine more generally + const int64_t ne_full = tensor->ne[get_split_axis()]; + GGML_ASSERT(ne_full % granularity == 0); + std::vector tensor_split_scan; + tensor_split_scan.reserve(ud->n_devices); + for (size_t j = 0; j < ud->n_devices; j++) { + tensor_split_scan.push_back(ud->tensor_split[j]); + if (j > 0) { + tensor_split_scan[j] += tensor_split_scan[j - 1]; + } + } + int64_t low = 0; + size_t j = 0; + for (; j < ud->n_devices - 1; j++) { + int64_t high = tensor_split_scan.back() == 0.0f ? + ne_full * (j+1)/ud->n_devices : ne_full * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % granularity != 0) { + high -= high % granularity; + } + split_state.ne[j] = high - low; + low = high; + } + split_state.ne[j] = ne_full - low; + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + } + return split_state; + GGML_UNUSED(userdata); +} + const char * llm_type_name(llm_type type) { switch (type) { case LLM_TYPE_14M: return "14M"; @@ -7610,6 +7707,10 @@ size_t llama_model::n_devices() const { return devices.size(); } +const float * llama_model::tensor_split() const { + return params.tensor_split; +} + uint32_t llama_model::n_gpu_layers() const { return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; } diff --git a/src/llama-model.h b/src/llama-model.h index adc8ff6479..c9ff9a991b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -438,6 +438,13 @@ struct llama_layer { struct llama_layer_nextn nextn; }; +struct llama_meta_device_get_split_state_userdata { + size_t n_devices; + const float * tensor_split; +}; + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata); + struct llama_model { llm_type type = LLM_TYPE_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -498,6 +505,9 @@ struct llama_model { // for keeping track of associated LoRA adapters std::unordered_set loras; + // statically allocated context for assigning + struct llama_meta_device_get_split_state_userdata get_split_state_ud; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -518,6 +528,7 @@ struct llama_model { size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + const float * tensor_split() const; uint32_t n_gpu_layers() const; llama_split_mode split_mode() const; diff --git a/src/llama.cpp b/src/llama.cpp index 6e198fa901..bee9567352 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -884,67 +884,6 @@ static int llama_model_load(const std::string & fname, std::vector return 0; } -static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(const struct ggml_tensor * tensor, void * userdata) { - // attention - const std::regex pattern_qkv_weight("blk\\.\\d*\\.attn_(q|k|v).weight"); - if (std::regex_match(tensor->name, pattern_qkv_weight)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE1; - } - const std::regex pattern_qkv_bias("blk\\.\\d*\\.attn_(q|k|v)\\.bias"); - if (std::regex_match(tensor->name, pattern_qkv_bias)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); - if (std::regex_match(tensor->name, pattern_qk_norm)) { - return tensor->ne[1] == 1 ? 
GGML_BACKEND_SPLIT_STATE_MIRRORED : GGML_BACKEND_SPLIT_STATE_BY_NE1; - } - const std::regex pattern_kv_cache("cache_(k|v)_l\\d*"); - const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight"); - if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight"); - if (std::regex_match(tensor->name, pattern_attn_out_weight)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias"); - if (std::regex_match(tensor->name, pattern_attn_out_bias)) { - return GGML_BACKEND_SPLIT_STATE_MIRRORED; - } - - // FFN - const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); - if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE1; - } - const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); - if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight"); - if (std::regex_match(tensor->name, pattern_ffn_down_weight)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias"); - if (std::regex_match(tensor->name, pattern_ffn_down_bias)) { - return GGML_BACKEND_SPLIT_STATE_MIRRORED; - } - - // output - const std::regex pattern_output_weight("output\\.weight"); - if (std::regex_match(tensor->name, pattern_output_weight)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE1; - } - const std::regex pattern_output_bias("output\\.bias"); - if (std::regex_match(tensor->name, pattern_output_bias)) { - return GGML_BACKEND_SPLIT_STATE_BY_NE0; - } - - // everything else - return GGML_BACKEND_SPLIT_STATE_MIRRORED; - GGML_UNUSED(userdata); -} - static struct llama_model * llama_model_load_from_file_impl( const std::string & path_model, std::vector & splits, @@ -982,7 +921,10 @@ static struct llama_model * llama_model_load_from_file_impl( while (params.devices[n_devs]) { n_devs++; } - model->devices.push_back(ggml_backend_meta_device(params.devices, n_devs, llama_meta_device_get_tensor_split, nullptr)); + model->get_split_state_ud.n_devices = n_devs; + model->get_split_state_ud.tensor_split = model->tensor_split(); + model->devices.push_back(ggml_backend_meta_device( + params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud)); } else { for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { model->devices.push_back(*dev); @@ -1004,7 +946,10 @@ static struct llama_model * llama_model_load_from_file_impl( } GGML_ASSERT(devs.size() >= 2); GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type()); - gpus.push_back(ggml_backend_meta_device(devs.data(), devs.size() - 1, llama_meta_device_get_tensor_split, nullptr)); + model->get_split_state_ud.n_devices = devs.size() - 1; + model->get_split_state_ud.tensor_split = model->tensor_split(); + gpus.push_back(ggml_backend_meta_device( + devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud)); } else { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i);
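The core of the new uneven-split support is the boundary computation in llama_meta_device_get_split_state: device shares come from a prefix sum over the user's tensor_split fractions, each boundary is rounded down to the split granularity, and the last device absorbs the remainder, so the split dimension no longer has to be divisible by the device count. A standalone sketch of that computation with illustrative weights and sizes, using a plain std::vector instead of the fixed-size ne array:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Partition ne_full elements across devices according to relative weights,
    // rounding each boundary down to a multiple of `granularity`; the last device
    // takes whatever is left, so no rows are ever dropped.
    static std::vector<int64_t> partition(int64_t ne_full, const std::vector<float> & weights, int64_t granularity) {
        const size_t n = weights.size();
        std::vector<float> scan(weights);
        for (size_t j = 1; j < n; j++) {
            scan[j] += scan[j - 1];
        }
        std::vector<int64_t> ne(n);
        int64_t low = 0;
        for (size_t j = 0; j + 1 < n; j++) {
            int64_t high = scan.back() == 0.0f ?
                ne_full * int64_t(j + 1) / int64_t(n) : int64_t(ne_full * scan[j] / scan.back());
            high -= high % granularity;
            ne[j] = high - low;
            low   = high;
        }
        ne[n - 1] = ne_full - low;
        return ne;
    }

    int main() {
        // e.g. 3 devices with a 3:2:1 tensor split and a row count not divisible by 3
        const std::vector<int64_t> ne = partition(11008, {3.0f, 2.0f, 1.0f}, 32);
        for (size_t j = 0; j < ne.size(); j++) {
            printf("device %zu: %lld rows\n", j, (long long) ne[j]);
        }
        return 0;
    }

With these inputs the sketch yields 5504/3648/1856 rows, which sum back to 11008; rounding only ever moves rows toward the last device, matching the patch's choice of giving the remainder to device n-1 rather than redistributing it.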