GGML: HIP: add RCCL support
This commit is contained in:
parent
8de41b5b40
commit
29c5327d01
|
|
@ -212,6 +212,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
|
|||
|
||||
option(GGML_HIP "ggml: use HIP" OFF)
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
|
||||
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
|
||||
|
|
|
|||
|
|
@ -1055,7 +1055,9 @@ struct ggml_cuda_device_info {
|
|||
|
||||
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
|
||||
|
||||
#ifdef GGML_USE_NCCL
|
||||
ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
|
||||
#endif // GGML_USE_NCCL
|
||||
};
|
||||
|
||||
const ggml_cuda_device_info & ggml_cuda_info();
|
||||
|
|
|
|||
|
|
@ -323,6 +323,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_NCCL
|
||||
int dev_ids[GGML_CUDA_MAX_DEVICES];
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
dev_ids[id] = id;
|
||||
|
|
@ -330,6 +331,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
|
||||
|
||||
return info;
|
||||
#endif // GGML_USE_NCCL
|
||||
}
|
||||
|
||||
const ggml_cuda_device_info & ggml_cuda_info() {
|
||||
|
|
@ -1099,6 +1101,10 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
|
|||
|
||||
return true;
|
||||
#else
|
||||
// If NCCL is installed it is used by default for optimal performance.
|
||||
// However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
|
||||
// RCCL is disabled by default, users are explicitly opting in.
|
||||
// Therefore print no warning for RCCL.
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
static bool warning_printed = false;
|
||||
if (!warning_printed) {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@
|
|||
#include <rocwmma/rocwmma-version.hpp>
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
|
||||
#ifdef GGML_USE_NCCL
|
||||
#include <rccl/rccl.h>
|
||||
#endif // GGML_USE_NCCL
|
||||
|
||||
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
|
|
@ -28,6 +33,7 @@
|
|||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ find_package(hip REQUIRED)
|
|||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
|
||||
if (GGML_HIP_RCCL)
|
||||
find_package(rccl REQUIRED)
|
||||
endif()
|
||||
|
||||
if (${hip_VERSION} VERSION_LESS 6.1)
|
||||
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
|
||||
endif()
|
||||
|
|
@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA)
|
|||
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_RCCL)
|
||||
add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL.
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_EXPORT_METRICS)
|
||||
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
|
||||
endif()
|
||||
|
|
@ -137,4 +145,8 @@ if (GGML_STATIC)
|
|||
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_RCCL)
|
||||
target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)
|
||||
|
|
|
|||
Loading…
Reference in New Issue