diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 2e976f3fc1..93644e70dc 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -212,6 +212,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index be96f3acaf..f9f85b30dc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1055,7 +1055,9 @@ struct ggml_cuda_device_info { std::array default_tensor_split = {}; +#ifdef GGML_USE_NCCL ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; +#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3abf34f979..a90e370560 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -323,6 +323,7 @@ static ggml_cuda_device_info ggml_cuda_init() { } } +#ifdef GGML_USE_NCCL int dev_ids[GGML_CUDA_MAX_DEVICES]; for (int id = 0; id < info.device_count; ++id) { dev_ids[id] = id; @@ -330,6 +331,7 @@ static ggml_cuda_device_info ggml_cuda_init() { NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); return info; +#endif // GGML_USE_NCCL } const ggml_cuda_device_info & ggml_cuda_info() { @@ -1099,6 +1101,10 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t return true; #else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static bool warning_printed = false; if (!warning_printed) { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5cc1b54319..4f6c059b81 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -28,6 +33,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 80037d2436..4a1564865c 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -43,6 +43,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() @@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -137,4 +145,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)