diff --git a/common/arg.cpp b/common/arg.cpp
index 18f953a38e..723fda71f3 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2331,19 +2331,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
     add_opt(common_arg(
-        {"-sm", "--split-mode"}, "{none,layer,row}",
+        {"-sm", "--split-mode"}, "{none,layer,row,tensor}",
        "how to split the model across multiple GPUs, one of:\n"
        "- none: use one GPU only\n"
-        "- layer (default): split layers and KV across GPUs\n"
-        "- row: split rows across GPUs",
+        "- layer (default): split layers and KV across GPUs (pipelined)\n"
+        "- row: split weights across GPUs by rows (parallelized)\n"
+        "- tensor: split weights and KV across GPUs (parallelized)",
        [](common_params & params, const std::string & value) {
-            std::string arg_next = value;
-            if (arg_next == "none") {
+            if (value == "none") {
                params.split_mode = LLAMA_SPLIT_MODE_NONE;
-            } else if (arg_next == "layer") {
+            } else if (value == "layer") {
                params.split_mode = LLAMA_SPLIT_MODE_LAYER;
-            } else if (arg_next == "row") {
+            } else if (value == "row") {
                params.split_mode = LLAMA_SPLIT_MODE_ROW;
+            } else if (value == "tensor") {
+                params.split_mode = LLAMA_SPLIT_MODE_TENSOR;
            } else {
                throw std::invalid_argument("invalid value");
            }
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 4323afe57b..e145a672ad 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9)
 set(GGML_VERSION_PATCH 7)
 set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
 
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+
 find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
 if(GIT_EXE)
     # Get current git commit hash
@@ -203,12 +205,14 @@ option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"
 option(GGML_CUDA_FA                         "ggml: compile ggml FlashAttention CUDA kernels" ON)
 option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"    OFF)
 option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"         ${GGML_CUDA_GRAPHS_DEFAULT})
+option(GGML_CUDA_NCCL                       "ggml: use NVIDIA Collective Comm. Library"      ON)
 set   (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
                                             "ggml: cuda link binary compression mode; requires cuda 12.8+")
 set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")
 
 option(GGML_HIP                             "ggml: use HIP"                                  OFF)
 option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental, slow"        OFF)
+option(GGML_HIP_RCCL                        "ggml: use ROCm Collective Comm. Library"        OFF)
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                ON)
 option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"        OFF)
 option(GGML_HIP_MMQ_MFMA                    "ggml: enable MFMA MMA for CDNA in MMQ"          ON)
diff --git a/ggml/cmake/FindNCCL.cmake b/ggml/cmake/FindNCCL.cmake
new file mode 100644
index 0000000000..67511e2d56
--- /dev/null
+++ b/ggml/cmake/FindNCCL.cmake
@@ -0,0 +1,36 @@
+# cmake/FindNCCL.cmake
+
+# NVIDIA does not distribute CMake files with NCCL, therefore use this file to find it instead.
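The new `GGML_CUDA_NCCL`/`GGML_HIP_RCCL` options and this `FindNCCL.cmake` module back the `ggml_backend_cuda_allreduce_tensor` entry point declared further down in `ggml-cuda.h`. For orientation only, the kind of NCCL call sequence such an allreduce maps to looks roughly like the hypothetical helper below (in-place float sum; communicator and stream setup via e.g. `ncclCommInitAll` is assumed to happen elsewhere; this is a sketch, not the patch's implementation):

```cpp
#include <nccl.h>
#include <cuda_runtime.h>

// Hypothetical sketch: sum-allreduce one float buffer per device, in place.
// comms[i] and streams[i] are assumed to have been created elsewhere.
static bool allreduce_f32_inplace(float ** bufs, size_t count,
                                  ncclComm_t * comms, cudaStream_t * streams, int n_devices) {
    ncclResult_t status = ncclSuccess;
    ncclGroupStart(); // batch the per-device calls so NCCL can launch them together
    for (int i = 0; i < n_devices; i++) {
        // sendbuff == recvbuff is allowed by NCCL (in-place operation)
        const ncclResult_t err = ncclAllReduce(bufs[i], bufs[i], count, ncclFloat32, ncclSum, comms[i], streams[i]);
        if (err != ncclSuccess) {
            status = err;
        }
    }
    ncclGroupEnd();
    return status == ncclSuccess;
}
```

If NCCL/RCCL is not available, the meta backend still works: `graph_compute` queries the registry for `"ggml_backend_allreduce_tensor"` and uses the generic pairwise-exchange fallback in `ggml-backend-meta.cpp` when that lookup fails or the call returns false.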
+ +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES include +) + +find_library(NCCL_LIBRARY + NAMES nccl + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES lib lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL + DEFAULT_MSG + NCCL_LIBRARY NCCL_INCLUDE_DIR +) + +if(NCCL_FOUND) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL UNKNOWN IMPORTED) + set_target_properties(NCCL::NCCL PROPERTIES + IMPORTED_LOCATION "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a9d1778641..da7e1c1c0d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -68,7 +68,7 @@ extern "C" { GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst); // // Backend (stream) @@ -83,13 +83,17 @@ extern "C" { GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); // "offset" refers to the offset in tensor->data for setting/getting data - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void 
ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -109,7 +113,7 @@ extern "C" { // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); @@ -135,7 +139,9 @@ extern "C" { // integrated GPU device using host memory GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) - GGML_BACKEND_DEVICE_TYPE_ACCEL + GGML_BACKEND_DEVICE_TYPE_ACCEL, + // "meta" device wrapping multiple other devices for tensor parallelism + GGML_BACKEND_DEVICE_TYPE_META, }; // functionality supported by the device @@ -196,7 +202,9 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // Split buffer type for tensor parallelism + // AllReduce operation for tensor parallelism (meta backend) + typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); @@ -211,6 +219,62 @@ extern "C" { }; typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg); + // + // Meta backend + // + + #define GGML_BACKEND_META_MAX_DEVICES 16 + + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum + + // for internal bookkeeping only: + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + int64_t ne[GGML_BACKEND_META_MAX_DEVICES]; + }; + + // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: + typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); + + GGML_API bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); + GGML_API size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev); + GGML_API ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index); + + // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this: + GGML_API ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, 
ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud); + + GGML_API bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft); + GGML_API ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index); + + GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf); + GGML_API size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf); + GGML_API ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index); + GGML_API struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index); + + GGML_API bool ggml_backend_is_meta(ggml_backend_t backend); + GGML_API size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend); + GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); + + GGML_API struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); + + // temporary workaround to statically allocate tensors from a context in a deduplicated way: + GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); + // // Backend registry // diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 22ad2c0096..5436c7ef57 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +// conduct allreduce operation between devices +GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 265023733e..e3f68fc83e 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -200,6 +200,7 @@ add_library(ggml-base ggml.cpp ggml-alloc.c ggml-backend.cpp + ggml-backend-meta.cpp ggml-opt.cpp ggml-threading.cpp ggml-threading.h diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 59190b7c46..baaddb0d6e 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -2,7 +2,9 @@ // ggml-backend internal header +#include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml.h" #ifdef __cplusplus extern "C" { @@ -49,6 +51,10 @@ extern "C" { void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) 2d data copies + void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t 
offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // clear the entire buffer @@ -90,8 +96,10 @@ extern "C" { void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations (required if the backend supports async operations) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp new file mode 100644 index 0000000000..f37c2e2388 --- /dev/null +++ b/ggml/src/ggml-backend-meta.cpp @@ -0,0 +1,1529 @@ +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ggml_backend_meta_device; +struct ggml_backend_meta_buffer_type; +struct ggml_backend_meta_buffer; +struct ggml_backend_meta; + +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + +// +// meta backend device +// + +struct ggml_backend_meta_device_context { + std::vector simple_devs; + ggml_backend_meta_get_split_state_t get_split_state; + void * get_split_state_ud; + + std::string name; + std::string description; + + ggml_backend_meta_device_context( + std::vector simple_devs, ggml_backend_meta_get_split_state_t get_splite_state, void * get_split_state_ud) : + simple_devs(std::move(simple_devs)), get_split_state(get_splite_state), get_split_state_ud(get_split_state_ud) { + name = std::string("Meta("); + description = std::string("Meta("); + for (size_t i = 0; i < simple_devs.size(); i++) { + if (i > 0) { + name += ","; + description += ","; + } + name += ggml_backend_dev_name (simple_devs[i]); + description += 
ggml_backend_dev_description(simple_devs[i]); + } + name += ")"; + description += ")"; + } + + bool operator<(const ggml_backend_meta_device_context & other) const { + return std::tie(simple_devs, get_split_state, get_split_state_ud) + < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); + } +}; + +static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->name.c_str(); +} + +static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->description.c_str(); +} + +static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + *free = 0; + *total = 0; + for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { + size_t tmp_free, tmp_total; + ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); + *free += tmp_free; + *total += tmp_total; + } +} + +static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_META; + + GGML_UNUSED(dev); +} + +static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + // TODO replace placeholders + props->name = ggml_backend_meta_device_get_name(dev); + props->description = ggml_backend_meta_device_get_description(dev); + props->type = ggml_backend_meta_device_get_type(dev); + props->device_id = 0; + + ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, // Not implemented. + /* .buffer_from_host_ptr = */ false, // Not implemented. + /* .events = */ false, // Not implemented. 
+ }; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_dev_props tmp_props; + ggml_backend_dev_get_props(simple_dev, &tmp_props); + props->caps.async = props->caps.async && tmp_props.caps.async; + props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; + props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; + props->caps.events = props->caps.events && tmp_props.caps.events; + } +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); + +static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), + [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); +} + +static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); + if (!ggml_backend_dev_is_meta(dev_buft)) { + return false; + } + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; + if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { + return false; + } + for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { + if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { + return false; + } + } + return true; +} + +static const ggml_backend_device_i ggml_backend_meta_device_iface = { + /* .get_name = */ ggml_backend_meta_device_get_name, + /* .get_description = */ ggml_backend_meta_device_get_description, + /* .get_memory = */ ggml_backend_meta_device_get_memory, + /* .get_type = */ ggml_backend_meta_device_get_type, + /* .get_props = */ ggml_backend_meta_device_get_props, + /* .init_backend = */ ggml_backend_meta_device_init_backend, + /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ nullptr, + /* .supports_op = */ ggml_backend_meta_device_supports_op, + /* .supports_buft = */ ggml_backend_meta_device_supports_buft, + /* .offload_op = */ nullptr, + /* .event_new = */ nullptr, + /* .event_free = */ nullptr, + /* .event_synchronize = */ nullptr, +}; + +bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { + return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; +} + +size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + return meta_dev_ctx->simple_devs.size(); +} + +ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { + 
GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); + return meta_dev_ctx->simple_devs[index]; +} + +ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { + GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); + static std::vector> ctxs; + static std::map meta_devs; + + std::vector simple_devs; + simple_devs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_devs.push_back(devs[i]); + } + ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); + + { + auto it = meta_devs.find(ctx); + if (it != meta_devs.end()) { + return &it->second; + } + } + ctxs.push_back(std::make_unique(ctx)); + + struct ggml_backend_device meta_dev = { + /*iface =*/ ggml_backend_meta_device_iface, + /*reg =*/ nullptr, + /*ctx =*/ ctxs.back().get(), + }; + + auto result = meta_devs.emplace(*ctxs.back(), meta_dev); + return &result.first->second; +} + +// +// meta backend buffer type +// + +struct ggml_backend_meta_buffer_type_context { + std::vector simple_bufts; + + std::string name; + + ggml_backend_meta_buffer_type_context(std::vector simple_bufts) : simple_bufts(std::move(simple_bufts)) { + name = "Meta("; + for (size_t i = 0; i < simple_bufts.size(); i++) { + if (i > 0) { + name += ","; + } + name += ggml_backend_buft_name(simple_bufts[i]); + } + name += ")"; + } + + bool operator<(const ggml_backend_meta_buffer_type_context & other) const { + return simple_bufts < other.simple_bufts; + } +}; + +static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; + return meta_buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + +static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alignment = 1; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); + max_alignment = std::max(max_alignment, alignment); + GGML_ASSERT(max_alignment % alignment == 0); + } + return max_alignment; +} + +static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_size = SIZE_MAX; + for (size_t i = 0; i < n_simple_bufts; i++) { + max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); + } + return max_size; +} + +static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alloc_size = 0; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); + max_alloc_size = std::max(max_alloc_size, alloc_size); + } + return max_alloc_size; +} + +static bool 
ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + for (size_t i = 0; i < n_simple_bufts; i++) { + if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { + return false; + } + } + return true; +} + +static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { + /* .get_name = */ ggml_backend_meta_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_meta_buffer_type_is_host, +}; + +bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { + return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { + static std::map meta_bufts; + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + { + auto it = meta_bufts.find(dev); + if (it != meta_bufts.end()) { + return &it->second; + } + } + + const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); + std::vector simple_bufts; + simple_bufts.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); + } + ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); + + struct ggml_backend_buffer_type meta_buft = { + /*iface =*/ ggml_backend_meta_buffer_type_iface, + /*device =*/ dev, + /*ctx =*/ buft_ctx, + }; + auto result = meta_bufts.emplace(dev, meta_buft); + return &result.first->second; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + ggml_backend_buffer_type_t host_buft = nullptr; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); + if (simple_host_buft == nullptr) { + return nullptr; + } + if (host_buft == nullptr) { + host_buft = simple_host_buft; + } else if (host_buft != simple_host_buft) { + // if different simple devices have different host buffer types, + // we cannot provide a single host buffer type for the meta device + return nullptr; + } + } + return host_buft; +} + +size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + return meta_buft_ctx->simple_bufts.size(); +} + +ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); + return meta_buft_ctx->simple_bufts[index]; +} + +// +// meta backend buffer +// + +struct ggml_backend_meta_buffer_context { + std::map, ggml_backend_meta_split_state> 
split_state_cache; + std::map< const ggml_tensor *, std::vector> simple_tensors; + + struct buffer_config { + ggml_context * ctx; + ggml_backend_buffer_t buf; + + buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} + }; + std::vector buf_configs; + + int debug; + + ggml_backend_meta_buffer_context() { + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } +}; + +static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (auto & [ctx, buf] : buf_ctx->buf_configs) { + ggml_backend_buffer_free(buf); + ggml_free(ctx); + } + delete buf_ctx; +} + +static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return (void *) 0x1000000000000000; // FIXME +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + GGML_ASSERT(split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + + int split_dim = split_state.axis; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ne[k] = tensor->ne[k]; + nb[k] = tensor->nb[k]; + } + + std::vector simple_tensors; + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { + ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; + ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + // TODO: the following assert fails for llama-parallel even though the results are correct: + // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = split_state.ne[j]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); + t_ij->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + t_ij->nb[i] = nb[i]; + } + t_ij->flags = tensor->flags; + memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); + ggml_set_name(t_ij, tensor->name); + t_ij->buffer = simple_buf; + t_ij->view_offs = tensor->view_offs; + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS && t_ij->view_offs > tensor->nb[split_dim]) { + t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; + } + t_ij->view_src = tensor->view_src; + if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { + t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); + } + if (t_ij->view_src != nullptr) { + t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; + } else if (simple_buf != nullptr) { + t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + } + t_ij->extra = tensor->extra; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t_ij->src[i] = tensor->src[i]; + if (tensor->src[i] == 
tensor) { + t_ij->src[i] = t_ij; + } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { + t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); + } + } + + simple_tensors.push_back(t_ij); + } + buf_ctx->simple_tensors[tensor] = simple_tensors; + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_bufs; j++){ + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". 
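These chunked set/get paths are built on the `ggml_backend_tensor_set_2d`/`_get_2d` functions added to `ggml-backend.h`. Judging from how they are called here (one chunk of `size` bytes per copy, `n_copies` copies, with independent strides on the tensor side and the host side), their semantics appear equivalent to a loop of strided 1D copies; a minimal reference sketch under that assumption:

```cpp
#include "ggml-backend.h"

// Assumed semantics of the new 2D copy helpers, expressed with the existing 1D API.
// Illustration only; real backends can map this onto cudaMemcpy2D-style fast paths.
static void tensor_set_2d_ref(struct ggml_tensor * t, const void * data, size_t offset,
                              size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
    for (size_t i = 0; i < n_copies; i++) {
        // i-th chunk: advance by stride_tensor bytes in the tensor, stride_data bytes in the host buffer
        ggml_backend_tensor_set(t, (const char *) data + i*stride_data, offset + i*stride_tensor, size);
    }
}

static void tensor_get_2d_ref(const struct ggml_tensor * t, void * data, size_t offset,
                              size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
    for (size_t i = 0; i < n_copies; i++) {
        ggml_backend_tensor_get(t, (char *) data + i*stride_data, offset + i*stride_tensor, size);
    }
}
```

In the split-axis cases above, `size` and `stride_tensor` are the per-device chunk size while `stride_data` is the full (unsplit) chunk size, so the contiguous host buffer is interleaved across the simple tensors.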
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get(simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); + } +} + +static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); + } +} + +static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { + /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, + /* .get_base = */ ggml_backend_meta_buffer_get_base, + /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, + /* .memset_tensor = */ nullptr, // TODO implement + /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_meta_buffer_clear, + /* .reset = */ ggml_backend_meta_buffer_reset, +}; + +bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { + return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; +} + +size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + return buf_ctx->buf_configs.size(); +} + +ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + return buf_ctx->buf_configs[index].buf; +} + +struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + + auto it = buf_ctx->simple_tensors.find(tensor); + if (it == buf_ctx->simple_tensors.end()) { + return nullptr; + } + return it->second[index]; +} + +static ggml_backend_buffer_t 
ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); + size_t max_size = 0; + buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); + buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + } + + return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); +} + +struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); + meta_buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); + } + + ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf; + ggml_backend_meta_buffer_init_tensor(meta_buf, t); + t->data = (void *) 0x2000000000000000; // FIXME + } + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( + meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + } + return meta_buf; +} + +// +// meta backend +// + +static ggml_guid_t ggml_backend_meta_guid() { + static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; + return &guid; +} + +struct ggml_backend_meta_context { + struct cgraph_config { + ggml_cgraph cgraph_main; + int offset; // Node offset vs. original graph, only used for debugging. + + std::vector cgraphs_aux; + std::vector nodes_aux; + + cgraph_config(ggml_cgraph cgraph_main, int offset) : cgraph_main(cgraph_main), offset(offset) {} + }; + struct backend_config { + ggml_backend_t backend; + + std::vector cgraphs; + std::vector nodes; + ggml_context * ctx = nullptr; + std::vector bufs; // Multiple buffers to reduce synchronizations. 
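Stepping back from the internals: the public surface of all of this is the handful of `ggml_backend_meta_*` functions declared in `ggml-backend.h`. A hypothetical end-to-end usage sketch follows; the split policy and device choice are purely illustrative and not how llama.cpp assigns splits:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Illustrative split policy: split the second dimension of 2D tensors evenly, mirror everything else.
static ggml_backend_meta_split_state example_split_state(const struct ggml_tensor * t, void * userdata) {
    const size_t n_devs = *(const size_t *) userdata;
    ggml_backend_meta_split_state s = {};
    if (ggml_n_dims(t) == 2 && t->ne[1] % (int64_t) n_devs == 0) {
        s.axis = GGML_BACKEND_SPLIT_AXIS_1;
        for (size_t j = 0; j < n_devs; j++) {
            s.ne[j] = t->ne[1] / (int64_t) n_devs;
        }
    } else {
        s.axis = GGML_BACKEND_SPLIT_AXIS_MIRRORED;
    }
    return s;
}

int main() {
    size_t n_devs = 2; // assumes at least two registered devices
    ggml_backend_dev_t devs[2] = { ggml_backend_dev_get(0), ggml_backend_dev_get(1) };

    ggml_backend_dev_t meta_dev = ggml_backend_meta_device(devs, n_devs, example_split_state, &n_devs);
    ggml_backend_t             backend = ggml_backend_dev_init(meta_dev, nullptr);
    ggml_backend_buffer_type_t buft    = ggml_backend_dev_buffer_type(meta_dev);

    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(ip);
    ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096); // weight that will be split across the devices

    // temporary helper from this patch (instead of ggml_backend_alloc_ctx_tensors_from_buft):
    ggml_backend_buffer_t buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft);

    // ... build a graph on ctx and run it with ggml_backend_graph_compute(backend, graph) ...

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
```

Weights split along the reduction dimension leave each device with a partial sum after a matmul; the `GGML_BACKEND_SPLIT_AXIS_PARTIAL` state and the allreduce path in `graph_compute` exist to resolve exactly that.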
+ + backend_config(ggml_backend_t backend) : backend(backend) {} + + ~backend_config() { + for (ggml_backend_buffer_t buf : bufs) { + ggml_backend_buffer_free(buf); + } + ggml_free(ctx); + } + }; + std::string name; + std::vector backend_configs; + size_t max_tmp_size = 0; + size_t max_subgraphs = 0; + + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { + const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + name = "Meta("; + backend_configs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); + if (i > 0) { + name += ","; + } + name += ggml_backend_dev_name(simple_dev); + backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params)); + } + name += ")"; + } + + ~ggml_backend_meta_context() { + for (auto & bc : backend_configs) { + ggml_backend_free(bc.backend); + } + } + + size_t n_reduce_steps() const { + return std::ceil(std::log2(backend_configs.size())); + } + + ggml_tensor * get_next_tensor(size_t j, std::vector & tensors, ggml_tensor * node) { + ggml_tensor * next = tensors[j] == nullptr ? ggml_get_first_tensor(backend_configs[j].ctx) + : ggml_get_next_tensor(backend_configs[j].ctx, tensors[j]); + if (next == nullptr) { + next = ggml_new_tensor_1d(backend_configs[j].ctx, GGML_TYPE_F32, 1); + } + memset(next, 0, sizeof(ggml_tensor)); + next->op = GGML_OP_NONE; + next->type = node->type; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + next->ne[dim] = node->ne[dim]; + next->nb[dim] = node->nb[dim]; + } + tensors[j] = next; + return next; + } +}; + +static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; + return backend_ctx->name.c_str(); +} + +static void ggml_backend_meta_free(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + delete backend_ctx; + delete backend; +} + +static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". 
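Further down, when a subgraph ends in a PARTIAL tensor and no backend-specific allreduce is available, the fallback pairs backend `j` with backend `j ^ offset` for `offset = 1, 2, 4, ...`, so after `ceil(log2(n))` exchange-and-add rounds every backend holds the full sum (this is what `n_reduce_steps()` counts). A toy standalone illustration of that exchange pattern, not ggml code, with a power-of-two participant count assumed:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const size_t n = 4;                           // number of "backends" (power of two for this toy)
    std::vector<float> vals = {1.0f, 2.0f, 3.0f, 4.0f};

    for (size_t offset = 1; offset < n; offset *= 2) {
        std::vector<float> next = vals;
        for (size_t j = 0; j < n; j++) {
            next[j] = vals[j] + vals[j ^ offset]; // exchange with the partner, then reduce
        }
        vals = next;
    }
    for (size_t j = 0; j < n; j++) {
        printf("backend %zu holds %.1f\n", j, vals[j]); // all print 10.0 after 2 rounds
    }
    return 0;
}
```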
+ const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_backends; j++) { + ggml_backend_tensor_set_async( + ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void ggml_backend_meta_synchronize(ggml_backend_t backend) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + for (size_t i = 0; i < n_backends; i++) { + ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); + } +} + +static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + const 
size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs.clear(); + bcj.nodes.clear(); + bcj.nodes.reserve(cgraph->n_nodes*n_reduce_steps); + + for (int i = 0; i < cgraph->n_nodes; i++) { + bcj.nodes.push_back(ggml_backend_meta_buffer_simple_tensor(cgraph->nodes[i], j)); + GGML_ASSERT(bcj.nodes[i]); + } + } + + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + { + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs.emplace_back(*cgraph, i_start); + bcj.cgraphs.back().cgraph_main.nodes = bcj.nodes.data() + i_start; + bcj.cgraphs.back().cgraph_main.n_nodes = i + 1 - i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + if (max_tmp_size > backend_ctx->max_tmp_size) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (ggml_backend_buffer_t buf : bcj.bufs) { + ggml_backend_buffer_free(buf); + } + bcj.bufs.clear(); + for (size_t k = 0; k < n_reduce_steps + 1; k++) { + bcj.bufs.push_back(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + } + backend_ctx->max_tmp_size = max_tmp_size; + } + if (n_subgraphs > backend_ctx->max_subgraphs) { + ggml_init_params params = { + /*.mem_size =*/ n_subgraphs*n_reduce_steps*2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_free(bcj.ctx); + bcj.ctx = ggml_init(params); + } + backend_ctx->max_subgraphs = n_subgraphs; + } + + size_t i_buf = 0; // Alternate between tmp buffers per simple backend to reduce synchronizations. + std::vector tensors(n_backends, nullptr); + + // Preferentially use backend-specific allreduce_tensor_async (e.g. 
NCCL for CUDA), use a generic fallback if unavailable: + auto allreduce_fallback = [&](size_t i) -> ggml_status { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + bcj.cgraphs[i].cgraphs_aux.clear(); + bcj.cgraphs[i].cgraphs_aux.reserve(n_reduce_steps); + bcj.cgraphs[i].nodes_aux.clear(); + bcj.cgraphs[i].nodes_aux.reserve(n_reduce_steps*2); + } + + for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + for (size_t j = 0; j < n_backends; j++) { + const size_t j_other = j ^ offset_j; + if (j_other > j) { + continue; + } + + auto & bcj1 = backend_ctx->backend_configs[j]; + auto & bcj2 = backend_ctx->backend_configs[j_other]; + + ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main.nodes[bcj1.cgraphs[i].cgraph_main.n_nodes-1]; + ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main.nodes[bcj2.cgraphs[i].cgraph_main.n_nodes-1]; + GGML_ASSERT(ggml_is_contiguous(node1)); + GGML_ASSERT(ggml_is_contiguous(node2)); + + ggml_tensor * node_tmp_1 = backend_ctx->get_next_tensor(j, tensors, node1); + ggml_tensor * node_tmp_2 = backend_ctx->get_next_tensor(j_other, tensors, node2); + node_tmp_1->buffer = bcj1.bufs[i_buf]; + node_tmp_2->buffer = bcj2.bufs[i_buf]; + node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.bufs[i_buf]); + node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.bufs[i_buf]); + bcj1.cgraphs[i].nodes_aux.push_back(node_tmp_1); + bcj2.cgraphs[i].nodes_aux.push_back(node_tmp_2); + + ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); + ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); + + ggml_tensor * node_red_1 = backend_ctx->get_next_tensor(j, tensors, node1); + ggml_tensor * node_red_2 = backend_ctx->get_next_tensor(j_other, tensors, node2); + node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; + node_red_2->view_src = node2->view_src == nullptr ? 
node2 : node2->view_src; + node_red_1->view_offs = node1->view_offs; + node_red_2->view_offs = node2->view_offs; + node_red_1->op = GGML_OP_ADD; + node_red_2->op = GGML_OP_ADD; + node_red_1->src[0] = node1; + node_red_2->src[0] = node2; + node_red_1->src[1] = node_tmp_1; + node_red_2->src[1] = node_tmp_2; + node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; + node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red_1); + ggml_backend_view_init(node_red_2); + bcj1.cgraphs[i].nodes_aux.push_back(node_red_1); + bcj2.cgraphs[i].nodes_aux.push_back(node_red_2); + + bcj1.cgraphs[i].cgraphs_aux.push_back(*cgraph); + bcj2.cgraphs[i].cgraphs_aux.push_back(*cgraph); + bcj1.cgraphs[i].cgraphs_aux.back().nodes = &bcj1.cgraphs[i].nodes_aux.back(); + bcj2.cgraphs[i].cgraphs_aux.back().nodes = &bcj2.cgraphs[i].nodes_aux.back(); + bcj1.cgraphs[i].cgraphs_aux.back().n_nodes = 1; + bcj2.cgraphs[i].cgraphs_aux.back().n_nodes = 1; + + i_buf = (i_buf + 1) % (n_reduce_steps + 1); + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, &bcj.cgraphs[i].cgraphs_aux.back()); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + return GGML_STATUS_SUCCESS; + }; + + + for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, &bcj.cgraphs[i].cgraph_main); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + + if (n_backends > 1 && i < n_subgraphs - 1) { + bool backend_allreduce_success = false; + ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor"); + if (allreduce_tensor) { + std::vector backends; + backends.reserve(n_backends); + std::vector nodes; + nodes.reserve(n_backends); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + backends.push_back(bcj.backend); + nodes.push_back(bcj.cgraphs[i].cgraph_main.nodes[bcj.cgraphs[i].cgraph_main.n_nodes-1]); + } + backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends); + } + + if (!backend_allreduce_success) { + const ggml_status status = allreduce_fallback(i); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + } + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_meta_i = { + /* .get_name = */ ggml_backend_meta_get_name, + /* .free = */ ggml_backend_meta_free, + /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, + /* .get_tensor_2d_async = */ nullptr, + /* .set_tensor_2d_async = */ nullptr, + /* .cpy_tensor_async = */ nullptr, + /* .synchronize = */ ggml_backend_meta_synchronize, + /* .graph_plan_create = */ nullptr, + /* .graph_plan_free = */ nullptr, + /* .graph_plan_update = */ nullptr, + /* .graph_plan_compute = */ nullptr, + /* .graph_compute = */ ggml_backend_meta_graph_compute, + /* .event_record = */ nullptr, + /* .event_wait = */ nullptr, + /* .graph_optimize = */ nullptr, +}; + +bool ggml_backend_is_meta(ggml_backend_t backend) { + return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; +} + +static ggml_backend_t 
ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { + ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); + + ggml_backend_t backend = new struct ggml_backend; + backend->guid = ggml_backend_meta_guid(); + backend->iface = ggml_backend_meta_i; + backend->device = dev; + backend->context = backend_ctx; + return backend; +} + +size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs.size(); +} + +ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs[index].backend; +} + +struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + if (a.ne[j] != b.ne[j]) { + return false; + } + } + return true; + }; + + auto handle_generic = [&](const std::vector & src_split_states, bool scalar_only) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}}; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + continue; + } + if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + homogeneous_src_split_state = src_split_states[i]; + } else if (!split_states_equal(src_split_states[i], homogeneous_src_split_state)) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + break; + } + } + if (homogeneous_src_split_state.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } + if (scalar_only && homogeneous_src_split_state.axis >= 0 && homogeneous_src_split_state.axis < GGML_MAX_DIMS) { + homogeneous_src_split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } + GGML_ASSERT(homogeneous_src_split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return homogeneous_src_split_state; + }; + + // Some ops process data on a per-row bases: + auto handle_per_row = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_0); + return src_split_states[0]; + }; + + // Some ops broadcast the src1 data across src0: + auto handle_bin_bcast = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + if (src_split_states[0].axis >= 0 && src_split_states[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_split_states[0].axis] == 1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_split_states[0]; + } + if (src_split_states[0].axis == src_split_states[1].axis && src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_split_states[0]; // GGML_ADD_ID + } + GGML_ASSERT(tensor->src[2] == nullptr 
|| src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return handle_generic(src_split_states, /*scalar_only =*/ false); + }; + + auto handle_mul_mat = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}}; + } + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_split_states[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + return ret; + } + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(src_split_states[0].ne[j] == src_split_states[1].ne[j]); + } + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}}; + } + GGML_ABORT("fatal error"); + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + }; + + auto handle_reshape = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + switch (src_split_states[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(ggml_is_contiguous(tensor)); + int64_t base_ne_in = 1; + for (int dim = 0; dim <= src_split_states[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; + } + int64_t base_ne_out = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; + if (base_ne_out_next == base_ne_in) { + return {ggml_backend_meta_split_axis(dim), {0}}; + } + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); + return {ggml_backend_meta_split_axis(dim + 1), {0}}; + } + base_ne_out = base_ne_out_next; + } + GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_split_states[0]; + } + default: { + GGML_ABORT("fatal error"); + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } + } + }; + + auto handle_view = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->view_src)) { + return handle_reshape(src_split_states); + } + if (src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + return src_split_states[0]; + } + GGML_ABORT("view of permuted tensor not implemented"); + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + }; + + auto handle_permute = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + switch (src_split_states[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + return {ggml_backend_meta_split_axis(tensor->op_params[src_split_states[0].axis]), {0}}; + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_split_states[0]; + } + default: { + GGML_ABORT("fatal error"); + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } + } + }; + + auto handle_set_rows = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_split_states[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + 
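+        // Together with handle_mul_mat above, these assertions encode how split states propagate.
+        // Illustrative example, assuming an even split across 2 GPUs (not taken from the patch):
+        //   - src0 split along axis 1 (weight rows), src1 mirrored
+        //       -> the mat-mul result is split along axis 0 with ne = {n/2, n/2}
+        //   - src0 and src1 both split along axis 0
+        //       -> each GPU holds a PARTIAL result, which is expected to be combined by the
+        //          allreduce step in ggml_backend_meta_graph_compute
+        // For SET_ROWS itself, src1 must be mirrored on every GPU, src0 and src2 must share the
+        // same split state, and the result inherits the split state of src0.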
GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_split_states[0], src_split_states[2])); + return src_split_states[0]; + }; + + auto handle_rope = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return src_split_states[0]; + }; + + auto handle_flash_attn_ext = [&](const std::vector & src_split_states) -> ggml_backend_meta_split_state { + GGML_ASSERT( src_split_states[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_split_states[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_split_states[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_split_states[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}}; + }; + + auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { + ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); + const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[j] % granularity == 0); + ne_sum += ret.ne[j]; + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; + } + + std::vector src_split_states(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}}); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + src_split_states[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + continue; + } + src_split_states[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + } + + ggml_backend_meta_split_state split_state; + switch (tensor->op) { + case GGML_OP_NONE: { + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}}; + } break; + case GGML_OP_DUP: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD_ID: { + split_state = handle_bin_bcast(src_split_states); + } break; + case GGML_OP_ADD1: + case GGML_OP_ACC: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: { + split_state = handle_bin_bcast(src_split_states); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_SUM: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { + split_state = handle_per_row(src_split_states); + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_CONCAT: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SILU_BACK: { + split_state = 
handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: { + split_state = handle_per_row(src_split_states); + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { + split_state = handle_mul_mat(src_split_states); + } break; + case GGML_OP_OUT_PROD: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SCALE: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_SET: + case GGML_OP_CPY: + case GGML_OP_CONT: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_RESHAPE: { + split_state = handle_reshape(src_split_states); + } break; + case GGML_OP_VIEW: { + split_state = handle_view(src_split_states); + } break; + case GGML_OP_PERMUTE: { + split_state = handle_permute(src_split_states); + } break; + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_GET_ROWS_BACK: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SET_ROWS: { + split_state = handle_set_rows(src_split_states); + } break; + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_DIAG_MASK_ZERO: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_ROPE: { + split_state = handle_rope(src_split_states); + } break; + case GGML_OP_ROPE_BACK: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_CLAMP: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ROLL: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: { + split_state = handle_per_row(src_split_states); + } break; + case GGML_OP_LEAKY_RELU: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_TRI: + case GGML_OP_FILL: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_FLASH_ATTN_EXT: { + split_state = handle_flash_attn_ext(src_split_states); + } break; + case GGML_OP_FLASH_ATTN_BACK: + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } break; + case GGML_OP_UNARY: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + case GGML_OP_CUSTOM: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ true); + } 
break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { + split_state = handle_per_row(src_split_states); + } break; + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_GLU: { + split_state = handle_generic(src_split_states, /*scalar_only =*/ false); + } break; + default: { + GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}}; + } break; + } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool src_split_by_axis_found = false; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_split_states[i].axis < 0 || src_split_states[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (src_split_by_axis_found) { + for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: + GGML_ASSERT( split_state.ne[j] * tensor->src[i]->ne[src_split_states[i].axis] + == src_split_states[i].ne[j] * tensor->ne[split_state.axis]); + } + } else { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + split_state.ne[j] = src_split_states[i].ne[j] * tensor->ne[split_state.axis]; + GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_split_states[i].axis] == 0); + split_state.ne[j] /= tensor->src[i]->ne[src_split_states[i].axis]; + } + } + src_split_by_axis_found = true; + } + GGML_ASSERT(src_split_by_axis_found); + } + return split_state; + }; + + const std::pair key = std::make_pair(tensor, assume_sync); + + if (buf_ctx->split_state_cache.find(key) == buf_ctx->split_state_cache.end()) { + buf_ctx->split_state_cache[key] = calculate_split_state(); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(buf_ctx->split_state_cache[key].ne[j]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].axis), ne_info.c_str()); + } + } + + ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key]; + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE && ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return ret; +} diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996c..1a555bf2a4 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized - if (buffer->size == 0) { + if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) { return NULL; } @@ -279,15 +279,57 @@ void 
ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten } } +void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -297,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); buf->iface.get_tensor(buf, tensor, data, offset, size); } +void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -388,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -402,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); -#endif +#endif // NDEBUG size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -411,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -500,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { } void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + GGML_ASSERT(device); memset(props, 0, sizeof(*props)); device->iface.get_props(device, props); } @@ -610,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* 
.get_tensor = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_multi_buffer_clear, /* .reset = */ NULL, @@ -1899,8 +1988,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); - GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= - (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) || + (char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); tensor->buffer = buffer; tensor->data = addr; @@ -2174,6 +2264,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, @@ -2186,6 +2278,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 2e9ddf2240..0bf295677e 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -260,6 +260,8 @@ static struct ggml_backend_i blas_backend_i = { /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3f3de9f0bc..c9837b19b7 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1355,6 +1355,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, /* .clear = */ ggml_backend_cann_buffer_clear, /* .reset = */ NULL, @@ -2567,6 +2569,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a571375..791b051cb2 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ 
b/ggml/src/ggml-cpu/amx/amx.cpp @@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, /* .cpy_tensor = */ nullptr, /* .clear = */ ggml_backend_amx_buffer_clear, /* .reset = */ nullptr, diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index ddf1737a31..49f840be20 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e..1fd965076d 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -182,6 +182,16 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) endif() + if (GGML_CUDA_NCCL) + find_package(NCCL) + if (NCCL_FOUND) + add_compile_definitions(GGML_USE_NCCL) + target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL) + else() + message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal") + endif() + endif() + set(CUDA_CXX_FLAGS "") set(CUDA_FLAGS -use_fast_math -extended-lambda) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a3256d59dd..f9f85b30dc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -186,6 +186,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) +#ifdef GGML_USE_NCCL +#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString) +#endif // GGML_USE_NCCL + #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; @@ -1050,6 +1054,10 @@ struct ggml_cuda_device_info { cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; std::array default_tensor_split = {}; + +#ifdef GGML_USE_NCCL + ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; +#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bed5c71a1b..2d1028d5af 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -309,6 +309,28 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } + } + } + +#ifdef GGML_USE_NCCL + int dev_ids[GGML_CUDA_MAX_DEVICES]; + for (int id = 0; id < info.device_count; ++id) { + dev_ids[id] = id; + } + NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); +#endif // GGML_USE_NCCL + return 
info; } @@ -617,26 +639,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -676,6 +718,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d, + /* .get_tensor_2d = */ 
ggml_backend_cuda_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, /* .clear = */ ggml_backend_cuda_buffer_clear, /* .reset = */ NULL, @@ -988,6 +1032,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_cuda_split_buffer_clear, /* .reset = */ NULL, @@ -1064,6 +1110,37 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; +bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) { +#ifdef GGML_USE_NCCL + const ggml_cuda_device_info info = ggml_cuda_info(); + + const size_t ne = ggml_nelements(tensors[0]); + + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + return true; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, tensors, n_backends); + return false; +#endif // GGML_USE_NCCL +} + ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -1371,64 +1448,6 @@ static void ggml_cuda_op_mul_mat_cublas( GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } -static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - CUDA_CHECK(cudaDeviceSynchronize()); - } - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - - for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { - if (id == id_other) { - continue; - } - if (id != main_device && id_other != main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); - if (err != cudaErrorPeerAccessAlreadyEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } else { - cudaError_t err = cudaDeviceDisablePeerAccess(id_other); - if (err != 
cudaErrorPeerAccessNotEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } - } - } - } - - ggml_cuda_set_device(main_device); -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; - - GGML_UNUSED(main_device); -} - static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { @@ -2420,11 +2439,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { - // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { - ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); - } - switch (dst->op) { case GGML_OP_ARGMAX: ggml_cuda_argmax(ctx, dst); @@ -2779,21 +2793,43 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) { } static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { @@ -2804,21 +2840,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) { return false; } // device -> device copy - ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; - ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context; - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context; if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); -#endif +#endif // NDEBUG return false; } @@ -2831,7 +2867,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; #else CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); -#endif +#endif // GGML_CUDA_NO_PEER_COPY } // record event on src stream after the copy @@ -4254,6 +4290,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, @@ -5033,6 +5071,9 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void *
ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_allreduce_tensor; + } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; } diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index ba032cfab4..1500e1b95f 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + #if CUDART_VERSION >= 12050 #include #endif // CUDART_VERSION >= 12050 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5cc1b54319..4f6c059b81 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -28,6 +33,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 54f9986498..2843be6db1 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1455,6 +1455,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, /* .clear = */ ggml_backend_hexagon_buffer_clear, /* .reset = */ NULL, @@ -2841,6 +2843,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 80037d2436..4a1564865c 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -43,6 +43,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() @@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. 
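+    # Illustrative build configuration for the collective-communication paths (option names as
+    # defined in ggml/CMakeLists.txt):
+    #   cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_NCCL=ON   # NCCL: ON by default, used when found
+    #   cmake -B build -DGGML_HIP=ON  -DGGML_HIP_RCCL=ON    # RCCL: explicit opt-in
+    # Both paths define GGML_USE_NCCL, since RCCL exposes the same API as NCCL.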
+endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -137,4 +145,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 1c705362fb..477dcff042 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -90,6 +90,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, @@ -563,6 +565,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ae3f79fd0d..6c2a3c7337 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3642,6 +3642,8 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -4999,6 +5001,8 @@ static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_opencl_buffer_clear, /* .reset = */ ggml_backend_opencl_buffer_reset, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c16..00452b2360 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -705,6 +705,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, /* .clear = */ ggml_backend_rpc_buffer_clear, /* .reset = */ NULL, @@ -893,6 +895,8 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* 
.set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0614d7e8f3..25e15a9bf7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -589,6 +589,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, /* .reset = */ ggml_backend_sycl_buffer_reset, @@ -4455,6 +4457,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp index 6b95362dd8..b6c561cd61 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -101,6 +101,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, @@ -113,6 +115,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index 5cd6c0c060..f22ce4113d 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,6 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 114992da08..3ed5e21be8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13091,6 +13091,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* 
.get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -14392,6 +14394,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 32e120266a..311bc1cbde 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2197,6 +2197,8 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .free = */ ggml_backend_webgpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -2362,6 +2364,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 9b6938abf7..e6b6fc24fd 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -313,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = { /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor, /* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_zdnn_buffer_clear, /* .reset = */ NULL, @@ -417,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, } static ggml_backend_i ggml_backend_zdnn_i = { - /* .get_name = */ ggml_backend_zdnn_name, - /* .free = */ ggml_backend_zdnn_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_zdnn_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, + /* .get_name = */ ggml_backend_zdnn_name, + /* .free = */ ggml_backend_zdnn_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_zdnn_graph_compute, + /* .event_record = */ NULL, + /* 
.event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 551c15bb4a..346450e603 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -240,6 +240,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/include/llama.h b/include/llama.h index d2d7f59ebc..dda15e8a4f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -189,9 +189,10 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_TENSOR = 3, }; // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fc05989aa5..87947ddac9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -975,9 +975,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void for (auto & backend : backends) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + if (reg) { + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index cb702b2a59..5920aa9a1d 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -187,7 +187,11 @@ llama_kv_cache::llama_kv_cache( t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it } } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer + if (ggml_backend_buft_is_meta(buft)) { + buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx.get(), buft); + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); // real buffer + } } if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index f0038036dc..9d040da4b1 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -1,5 +1,6 @@ #include "llama-memory-recurrent.h" +#include "ggml-backend.h" #include "llama-impl.h" #include "llama-io.h" #include "llama-batch.h" @@ -101,7 +102,8 @@ llama_memory_recurrent::llama_memory_recurrent( // allocate tensors and initialize the 
buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); + ggml_backend_buffer_t buf = ggml_backend_buft_is_meta(buft) ? + ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx.get(), buft) : ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for rs cache"); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c26584aa67..c10ca7f6c6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -18,14 +18,148 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + + const std::regex pattern_q_weight("blk\\.\\d*\\.attn_q.weight"); + const std::regex pattern_kv_weight("blk\\.\\d*\\.attn_(k|v).weight"); + const std::regex pattern_q_bias("blk\\.\\d*\\.attn_q\\.bias"); + const std::regex pattern_kv_bias("blk\\.\\d*\\.attn_(k|v)\\.bias"); + const std::regex pattern_qk_norm("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + const std::regex pattern_kv_cache("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks("blk\\.\\d*\\.attn_sinks.weight"); + const std::regex pattern_attn_out_weight("blk\\.\\d*\\.attn_output.weight"); + const std::regex pattern_attn_out_bias("blk\\.\\d*\\.attn_output.bias"); + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + const std::regex pattern_ffn_up_gate_bias("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + const std::regex pattern_ffn_down_weight("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + const std::regex pattern_ffn_down_bias("blk\\.\\d*\\.ffn_down(_exps)?.bias"); + const std::regex pattern_output_weight("output\\.weight"); + const std::regex pattern_output_bias("output\\.bias"); + + auto get_split_axis = [&]() -> ggml_backend_meta_split_axis { + // attention + if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_kv_weight)) { + return GGML_BACKEND_SPLIT_AXIS_1; + } + if (std::regex_match(tensor->name, pattern_q_bias) || std::regex_match(tensor->name, pattern_kv_bias)) { + return GGML_BACKEND_SPLIT_AXIS_0; + } + if (std::regex_match(tensor->name, pattern_qk_norm)) { + return tensor->ne[1] == 1 ? 
GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1;
+        }
+        if (std::regex_match(tensor->name, pattern_kv_cache) || std::regex_match(tensor->name, pattern_attn_sinks)) {
+            return GGML_BACKEND_SPLIT_AXIS_0;
+        }
+        if (std::regex_match(tensor->name, pattern_attn_out_weight)) {
+            return GGML_BACKEND_SPLIT_AXIS_0;
+        }
+        if (std::regex_match(tensor->name, pattern_attn_out_bias)) {
+            return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
+        }
+
+        // FFN
+        if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight)) {
+            return GGML_BACKEND_SPLIT_AXIS_1;
+        }
+        if (std::regex_match(tensor->name, pattern_ffn_up_gate_bias)) {
+            return GGML_BACKEND_SPLIT_AXIS_0;
+        }
+        if (std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+            return GGML_BACKEND_SPLIT_AXIS_0;
+        }
+        if (std::regex_match(tensor->name, pattern_ffn_down_bias)) {
+            return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
+        }
+
+        // output
+        if (std::regex_match(tensor->name, pattern_output_weight)) {
+            return GGML_BACKEND_SPLIT_AXIS_1;
+        }
+        if (std::regex_match(tensor->name, pattern_output_bias)) {
+            return GGML_BACKEND_SPLIT_AXIS_0;
+        }
+
+        // everything else
+        return GGML_BACKEND_SPLIT_AXIS_MIRRORED;
+    };
+
+    auto get_split_granularity = [&](ggml_backend_meta_split_axis split_axis) -> int64_t {
+        const int64_t blck_size = split_axis == GGML_BACKEND_SPLIT_AXIS_1 && tensor->ne[1] % 256 == 0 ? 256 : 32;
+
+        // attention
+        if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) ||
+            std::regex_match(tensor->name, pattern_attn_out_weight)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size);
+        }
+        if (std::regex_match(tensor->name, pattern_kv_weight) || std::regex_match(tensor->name, pattern_kv_bias) ||
+            std::regex_match(tensor->name, pattern_kv_cache)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size) / n_gqa;
+        }
+        if (std::regex_match(tensor->name, pattern_attn_sinks)) {
+            const uint32_t n_gqa = ud->model->hparams.n_gqa();
+            const uint32_t n_embd_q = n_gqa * ud->model->hparams.n_embd_head_k;
+            return std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa;
+        }
+
+        // FFN
+        if (std::regex_match(tensor->name, pattern_ffn_up_gate_weight) || std::regex_match(tensor->name, pattern_ffn_up_gate_bias) ||
+            std::regex_match(tensor->name, pattern_ffn_down_weight)) {
+            return blck_size;
+        }
+
+        // everything else
+        return 1;
+    };
+
+    ggml_backend_meta_split_state split_state;
+    split_state.axis = get_split_axis();
+    if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
+        const int64_t ne_full = tensor->ne[split_state.axis];
+        const int64_t granularity = get_split_granularity(split_state.axis);
+        GGML_ASSERT(ne_full % granularity == 0);
+        const float * tensor_split = ud->model->tensor_split();
+        std::vector<float> tensor_split_scan;
+        tensor_split_scan.reserve(ud->n_devices);
+        for (size_t j = 0; j < ud->n_devices; j++) {
+            tensor_split_scan.push_back(tensor_split[j]);
+            if (j > 0) {
+                tensor_split_scan[j] += tensor_split_scan[j - 1];
+            }
+        }
+        int64_t low = 0;
+        size_t j = 0;
+        for (; j < ud->n_devices - 1; j++) {
+            int64_t high = tensor_split_scan.back() == 0.0f ?
+ ne_full * (j+1)/ud->n_devices : ne_full * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % granularity != 0) { + high -= high % granularity; + } + split_state.ne[j] = high - low; + low = high; + } + split_state.ne[j] = ne_full - low; + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + } + return split_state; + GGML_UNUSED(userdata); +} + const char * llm_type_name(llm_type type) { switch (type) { case LLM_TYPE_14M: return "14M"; @@ -420,14 +554,16 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device extra buffer type (if any) ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + if (reg) { + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(dev, *extra_bufts); - ++extra_bufts; + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } } } @@ -7649,7 +7785,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them } } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + if (ggml_backend_buft_is_meta(buft)) { + buf = ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } } if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); @@ -7751,6 +7891,10 @@ size_t llama_model::n_devices() const { return devices.size(); } +const float * llama_model::tensor_split() const { + return params.tensor_split; +} + uint32_t llama_model::n_gpu_layers() const { return params.n_gpu_layers >= 0 ? 
params.n_gpu_layers : hparams.n_layer + 1;
 }
diff --git a/src/llama-model.h b/src/llama-model.h
index b350591429..1aafd017ed 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -446,6 +446,13 @@ struct llama_layer {
     struct llama_layer_nextn nextn;
 };
 
+struct llama_meta_device_get_split_state_userdata {
+    size_t n_devices;
+    const struct llama_model * model;
+};
+
+struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata);
+
 struct llama_model {
     llm_type type = LLM_TYPE_UNKNOWN;
     llm_arch arch = LLM_ARCH_UNKNOWN;
@@ -506,6 +513,9 @@ struct llama_model {
     // for keeping track of associated LoRA adapters
     std::unordered_set loras;
 
+    // statically allocated context for assigning
+    struct llama_meta_device_get_split_state_userdata get_split_state_ud;
+
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;
@@ -526,6 +536,7 @@ struct llama_model {
     size_t size() const; // file size
     size_t n_tensors() const;
     size_t n_devices() const;
+    const float * tensor_split() const;
 
     uint32_t n_gpu_layers() const;
     llama_split_mode split_mode() const;
diff --git a/src/llama.cpp b/src/llama.cpp
index 6da90d6f1f..1ea326107b 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21,7 +21,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -160,6 +162,9 @@ static void llama_params_fit_impl(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, size_t * margins_s,
         uint32_t n_ctx_min, enum ggml_log_level log_level) {
+    if (mparams->split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+        throw llama_params_fit_exception("llama_params_fit is not implemented for SPLIT_MODE_TENSOR, abort");
+    }
     constexpr int64_t MiB = 1024*1024;
     typedef std::vector dmds_t;
     const llama_model_params default_mparams = llama_model_default_params();
@@ -911,8 +916,19 @@ static struct llama_model * llama_model_load_from_file_impl(
     // create list of devices to use with this model
     if (params.devices) {
-        for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
-            model->devices.push_back(*dev);
+        if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+            size_t n_devs = 0;
+            while (params.devices[n_devs]) {
+                n_devs++;
+            }
+            model->get_split_state_ud.n_devices = n_devs;
+            model->get_split_state_ud.model = model;
+            model->devices.push_back(ggml_backend_meta_device(
+                params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud));
+        } else {
+            for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+                model->devices.push_back(*dev);
+            }
         }
     } else {
         // default device selection
@@ -922,47 +938,64 @@ static struct llama_model * llama_model_load_from_file_impl(
         std::vector<ggml_backend_dev_t> igpus;
         std::vector<ggml_backend_dev_t> rpc_servers;
 
-        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-            switch (ggml_backend_dev_type(dev)) {
-                case GGML_BACKEND_DEVICE_TYPE_CPU:
-                case GGML_BACKEND_DEVICE_TYPE_ACCEL:
-                    // skip CPU backends since they are handled separately
-                    break;
+        if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) {
+            std::vector<ggml_backend_dev_t> devs;
+            devs.reserve(ggml_backend_dev_count());
+            for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                devs.push_back(ggml_backend_dev_get(i));
+            }
+            GGML_ASSERT(devs.size() >= 2);
+            GGML_ASSERT(ggml_backend_dev_buffer_type(devs.back()) == ggml_backend_cpu_buffer_type());
+
model->get_split_state_ud.n_devices = devs.size() - 1; + model->get_split_state_ud.model = model; + gpus.push_back(ggml_backend_meta_device( + devs.data(), devs.size() - 1, llama_meta_device_get_split_state, &model->get_split_state_ud)); + } else { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + switch (ggml_backend_dev_type(dev)) { + case GGML_BACKEND_DEVICE_TYPE_CPU: + case GGML_BACKEND_DEVICE_TYPE_ACCEL: + // skip CPU backends since they are handled separately + break; - case GGML_BACKEND_DEVICE_TYPE_GPU: { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - if (ggml_backend_reg_name(reg) == std::string("RPC")) { - rpc_servers.push_back(dev); - } else { - // check if there is already a GPU with the same device id - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { - ggml_backend_dev_props d_props; - ggml_backend_dev_get_props(d, &d_props); - if (props.device_id && d_props.device_id) { - return strcmp(props.device_id, d_props.device_id) == 0; - } - return false; - }); - - if (it != gpus.end()) { - LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", - __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + case GGML_BACKEND_DEVICE_TYPE_GPU: { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back(dev); } else { - gpus.push_back(dev); - } - } - break; - } + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); - case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back(dev); - break; + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? 
props.device_id : "unknown id", + ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + } else { + gpus.push_back(dev); + } + } + break; + } + + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back(dev); + break; + case GGML_BACKEND_DEVICE_TYPE_META: + GGML_ABORT("fatal error"); + break; + } } } diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 7da6c3957c..1d4c5e5922 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -259,6 +259,8 @@ static const char * split_mode_str(llama_split_mode mode) { return "layer"; case LLAMA_SPLIT_MODE_ROW: return "row"; + case LLAMA_SPLIT_MODE_TENSOR: + return "tensor"; default: GGML_ABORT("invalid split mode"); } @@ -440,7 +442,7 @@ static void print_usage(int /* argc */, char ** argv) { join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -ncmoe, --n-cpu-moe (default: %s)\n", join(cmd_params_defaults.n_cpu_moe, ",").c_str()); - printf(" -sm, --split-mode (default: %s)\n", + printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); @@ -723,6 +725,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { mode = LLAMA_SPLIT_MODE_LAYER; } else if (m == "row") { mode = LLAMA_SPLIT_MODE_ROW; + } else if (m == "tensor") { + mode = LLAMA_SPLIT_MODE_TENSOR; } else { invalid_param = true; break; @@ -1685,7 +1689,7 @@ struct markdown_printer : public printer { return 6; } if (field == "split_mode") { - return 5; + return 6; } if (field == "flash_attn") { return 2;
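
A note on the partition arithmetic introduced in llama_meta_device_get_split_state above: for a split axis, the cumulative tensor_split fractions decide where each device's slice ends (an even split is used when all fractions are zero), every boundary is rounded down to a multiple of the per-tensor granularity, and the last device takes whatever remains. The standalone sketch below reproduces that arithmetic outside the patch; split_ne and the sample numbers are illustrative only, not part of the change.

// Illustrative re-implementation of the boundary computation performed by the
// split-state callback (not patch code): cumulative fractions are turned into
// granularity-aligned per-device slice sizes along the split axis.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> split_ne(int64_t ne_full, int64_t granularity,
                                     const std::vector<float> & tensor_split) {
    const size_t n_devices = tensor_split.size();
    std::vector<float> scan(tensor_split); // prefix sums of the fractions
    for (size_t j = 1; j < n_devices; j++) {
        scan[j] += scan[j - 1];
    }
    std::vector<int64_t> ne(n_devices);
    int64_t low = 0;
    for (size_t j = 0; j + 1 < n_devices; j++) {
        // even split when no fractions were given, proportional split otherwise
        int64_t high = scan.back() == 0.0f ?
            ne_full * int64_t(j + 1) / int64_t(n_devices) :
            int64_t(ne_full * scan[j] / scan.back());
        high -= high % granularity; // round down to a granularity boundary
        ne[j] = high - low;
        low   = high;
    }
    ne[n_devices - 1] = ne_full - low; // last device takes the remainder
    return ne;
}

int main() {
    // e.g. 14336 rows split 60/40 across two devices with granularity 256:
    // the 60 % boundary (8601) rounds down to 8448, leaving slices of 8448 and 5888
    for (int64_t n : split_ne(14336, 256, {0.6f, 0.4f})) {
        printf("%lld\n", (long long) n);
    }
    return 0;
}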
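For callers of the public API, opting into the new mode is a matter of setting the new enum value on llama_model_params before loading; the sketch below shows one way to do it (the model path and split fractions are placeholders, and error handling is trimmed), which is roughly what the new -sm tensor command-line value ends up configuring.

// Hypothetical application-side setup for LLAMA_SPLIT_MODE_TENSOR.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.split_mode   = LLAMA_SPLIT_MODE_TENSOR;    // new enum value from this change
    mparams.n_gpu_layers = 999;                        // offload all layers
    static const float tensor_split[] = {0.6f, 0.4f};  // optional per-device fractions
    mparams.tensor_split = tensor_split;

    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    // ... create a context and run inference as usual ...

    llama_model_free(model);
    llama_backend_free();
    return 0;
}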