first commit naive test to enable mmq for RDNA4

This commit is contained in:
jiachengjason 2025-08-17 15:16:09 +00:00 committed by jiachengjason
parent 96ac5a2329
commit 9f87b491bd
6 changed files with 182 additions and 142 deletions

View File

@ -210,6 +210,7 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental,
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_MMQ_WMMA "ggml: enable WMMA MMA for RDNA4 in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)

View File

@ -232,6 +232,9 @@ static const char * cu_get_error_str(CUresult err) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4) && !defined(GGML_HIP_NO_MMQ_WMMA)
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
@ -295,6 +298,15 @@ static bool volta_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}
static bool amd_wmma_available(const int cc) {
#if !defined(GGML_HIP_NO_MMQ_WMMA)
return GGML_CUDA_CC_IS_RDNA4(cc);
#else
return false;
#endif //!defined(GGML_HIP_NO_MMQ_WMMA)
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

View File

@ -425,7 +425,7 @@ namespace ggml_cuda_mma {
template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
@ -798,6 +798,22 @@ namespace ggml_cuda_mma {
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#elif defined(AMD_WMMA_AVAILABLE)
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
int32x16_t * acc = (int32x16_t *) D.x;
#if defined(RDNA4)
acc[0] = __builtin_amdgcn_wmma_i32_32x32x16_i8(A.x[0],
B.x[0],
acc[0],
0, 0, 0);
acc[0] = __builtin_amdgcn_wmma_i32_32x32x16_i8(A.x[1],
B.x[1],
acc[0],
0, 0, 0);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;

View File

@ -290,11 +290,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
if (amd_mfma_available(cc)||amd_wmma_available(cc)) {
// As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
// performs better but is currently suffering from a crash on this architecture.
// TODO: Revisit when hipblaslt is fixed on CDNA3
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
if (GGML_CUDA_CC_IS_CDNA3(cc)||GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {

File diff suppressed because it is too large Load Diff

View File

@ -116,6 +116,14 @@ if (NOT GGML_HIP_MMQ_MFMA)
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
endif()
if (NOT GGML_HIP_MMQ_WMMA)
add_compile_definitions(GGML_HIP_NO_MMQ_WMMA)
endif()
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
endif()
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()