diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 2304310bf0..9b7c5a2611 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -384,7 +384,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
         nb[k] = tensor->nb[k];
     }
     if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
-        GGML_ASSERT(ne[split_dim] % (n_simple_bufs*ggml_blck_size(tensor->type)) == 0);
+        GGML_ASSERT(ne[split_dim] % (split_dim == 0 ? n_simple_bufs*ggml_blck_size(tensor->type) : n_simple_bufs) == 0);
         ne[split_dim] /= n_simple_bufs;
         for (int i = 0; i < GGML_MAX_DIMS; i++) {
             if (tensor->nb[i] > tensor->nb[split_dim]) {
@@ -738,11 +738,44 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens
 }
 
 static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_ASSERT(ggml_backend_meta_get_split_state(tensor, false) == GGML_BACKEND_SPLIT_STATE_MIRRORED);
     const size_t n_backends = ggml_backend_meta_n_backends(backend);
-    GGML_ASSERT(n_backends >= 1);
-    ggml_backend_tensor_get_async( // TODO other backends may be more optimal
-        ggml_backend_meta_simple_backend(backend, 0), ggml_backend_meta_buffer_simple_tensor(tensor, 0), data, offset, size);
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false);
+
+    switch (split_state) {
+        case GGML_BACKEND_SPLIT_STATE_BY_NE0:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE1:
+        case GGML_BACKEND_SPLIT_STATE_BY_NE2: {
+            // Exploit that the tensor is contiguous to splice it together from the simple tensors in "chunks".
+            const size_t chunk_size_full = tensor->nb[int(split_state) + 1];
+            GGML_ASSERT(offset % chunk_size_full == 0);
+            GGML_ASSERT(size   % chunk_size_full == 0);
+            const int64_t i_start =  offset        /chunk_size_full;
+            const int64_t i_stop  = (offset + size)/chunk_size_full;
+            size_t offset_j = 0;
+            for (size_t j = 0; j < n_backends; j++) {
+                ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j);
+                const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
+                const size_t chunk_size_j = simple_tensor->nb[int(split_state) + 1];
+                for (int64_t i1 = i_start; i1 < i_stop; i1++) {
+                    ggml_backend_tensor_get_async(simple_backend, simple_tensor, (char *) data + i1*chunk_size_full + offset_j, i1*chunk_size_j, chunk_size_j);
+                }
+                offset_j += chunk_size_j;
+            }
+            GGML_ASSERT(offset_j == chunk_size_full);
+        } break;
+        case GGML_BACKEND_SPLIT_STATE_MIRRORED: {
+            // TODO: another simple backend may be better
+            ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0);
+            const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0);
+            ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
 }
 
 static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
diff --git a/src/llama.cpp b/src/llama.cpp
index 6f5d91c999..6e198fa901 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -931,10 +931,14 @@ static enum ggml_backend_meta_split_state llama_meta_device_get_tensor_split(con
     }
 
     // output
-    const std::regex pattern_output("output");
-    if (std::regex_match(tensor->name, pattern_output)) {
+    const std::regex pattern_output_weight("output\\.weight");
+    if (std::regex_match(tensor->name, pattern_output_weight)) {
         return GGML_BACKEND_SPLIT_STATE_BY_NE1;
     }
+    const std::regex pattern_output_bias("output\\.bias");
+    if (std::regex_match(tensor->name, pattern_output_bias)) {
+        return GGML_BACKEND_SPLIT_STATE_BY_NE0;
+    }
 
     // everything else
     return GGML_BACKEND_SPLIT_STATE_MIRRORED;
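
A note on the init_tensor hunk: along dimension 0 a tensor is measured in individual elements, and for quantized types those elements only exist in whole blocks of ggml_blck_size(tensor->type), so a dim-0 split must hand each simple buffer a whole number of blocks. Along dimensions 1..3 the unit is an entire row, so divisibility by the number of simple buffers alone is enough. A minimal standalone sketch of the two cases; the concrete numbers (4096 elements, 2 buffers, block size 32, the block size of Q4_0) are illustrative, not from the patch:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t n_simple_bufs = 2;  // how many simple buffers the tensor is split across
    const int64_t blck_size     = 32; // elements per quantization block, e.g. Q4_0

    // split_dim == 0: ne[0] counts elements, which are only addressable in whole
    // quantization blocks, so each shard must receive a multiple of blck_size elements.
    const int64_t ne0 = 4096;
    assert(ne0 % (n_simple_bufs*blck_size) == 0); // 4096 = 2 shards x 64 blocks x 32 elements

    // split_dim > 0: ne[split_dim] counts whole rows, so only the shard count matters.
    const int64_t ne1 = 4096;
    assert(ne1 % n_simple_bufs == 0);
    return 0;
}
```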
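The rewritten get_tensor path splices the contiguous destination buffer together from per-backend "chunks". Assuming the GGML_BACKEND_SPLIT_STATE_BY_NE0/1/2 enum values are 0/1/2 (as the int(split_state) + 1 indexing implies), a tensor split along dimension d consists of slabs of chunk_size_full = tensor->nb[d + 1] bytes, and backend j owns chunk_size_j consecutive bytes at offset offset_j inside every slab. The following self-contained sketch reproduces the gather with plain byte buffers and memcpy in place of ggml_backend_tensor_get_async; all names and sizes are illustrative:

```cpp
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    const size_t n_backends      = 2;
    const size_t n_chunks        = 3;                          // slabs along the dims above the split
    const size_t chunk_size_full = 8;                          // bytes per slab in the full tensor
    const size_t chunk_size_j    = chunk_size_full/n_backends; // bytes per slab owned by each backend

    // Each backend stores its slice of every slab contiguously ('A' for backend 0, 'B' for backend 1).
    std::vector<std::vector<char>> shards(n_backends, std::vector<char>(n_chunks*chunk_size_j));
    for (size_t j = 0; j < n_backends; j++) {
        for (char & byte : shards[j]) {
            byte = char('A' + j);
        }
    }

    // The gather from ggml_backend_meta_get_tensor_async: slab i1 of the full tensor is the
    // concatenation of slab i1 of every backend, each placed at byte offset offset_j.
    std::vector<char> data(n_chunks*chunk_size_full);
    size_t offset_j = 0;
    for (size_t j = 0; j < n_backends; j++) {
        for (size_t i1 = 0; i1 < n_chunks; i1++) {
            memcpy(data.data() + i1*chunk_size_full + offset_j, shards[j].data() + i1*chunk_size_j, chunk_size_j);
        }
        offset_j += chunk_size_j;
    }

    printf("%.*s\n", (int) data.size(), data.data()); // AAAABBBBAAAABBBBAAAABBBB
    return 0;
}
```

The final GGML_ASSERT(offset_j == chunk_size_full) in the patch checks exactly this tiling: the per-backend chunk sizes must add up to one full slab.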
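On the llama.cpp hunk: std::regex_match only succeeds when the pattern matches the entire string, so the old pattern "output" could never have matched the tensor name "output.weight"; the escaped patterns now match output.weight and output.bias exactly. Splitting the weight by ne1 and the bias by ne0 presumably keeps the shards aligned, since both cut along the output-feature dimension, so each device's slice of the matmul result pairs with its slice of the bias. A standalone check of the matching behavior (not from the patch):

```cpp
#include <cassert>
#include <regex>

int main() {
    const std::regex pattern_output_weight("output\\.weight");
    const std::regex pattern_output_bias("output\\.bias");

    assert( std::regex_match("output.weight", pattern_output_weight));
    assert( std::regex_match("output.bias",   pattern_output_bias));
    assert(!std::regex_match("output",             pattern_output_weight)); // whole string must match
    assert(!std::regex_match("output_norm.weight", pattern_output_weight));

    // Without the escape, '.' is a metacharacter matching any single character.
    const std::regex unescaped("output.weight");
    assert(std::regex_match("outputXweight", unescaped));
    return 0;
}
```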