From d132f22fc92f36848f7ccf2fc9987cd0b0120825 Mon Sep 17 00:00:00 2001 From: andyluo7 <43718156+andyluo7@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:13:32 +0300 Subject: [PATCH] HIP: add CDNA4 (gfx950) architecture support for MI350X/MI355X (#21570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add AMD Instinct MI350X/MI355X (gfx950, CDNA4) support: - vendors/hip.h: Add CDNA4 preprocessor define for __gfx950__ - common.cuh: Add GGML_CUDA_CC_CDNA4 and GGML_CUDA_CC_IS_CDNA4 macros - mma.cuh: Route CDNA4 to compatible MFMA instructions: * f32 matmul: mfma_f32_16x16x4f32 (xf32 variant unavailable on gfx950) * bf16 matmul: mfma_f32_16x16x16bf16_1k (same as CDNA3) * int8 matmul: mfma_i32_16x16x32_i8/32x32x16 (same as CDNA3) - mmq.cuh: Include CDNA4 in stream-k kernel dispatch CDNA4 is largely compatible with CDNA3 except: - No xf32 MFMA (mfma_f32_16x16x8_xf32) — routes to f32 path - Different FP8 format (e4m3fn vs e4m3_fnuz) — not changed here Tested on AMD Instinct MI355X (gfx950), ROCm 7.0.1: - Build: compiles cleanly with -DAMDGPU_TARGETS=gfx950 - llama-bench (Qwen2.5-1.5B Q4_K_M, single GPU): * f16+FA: 40,013 tok/s prefill, 254 tok/s decode * q8_0+FA: functional - Flash attention: works correctly - MMQ: works correctly with stream-k dispatch Co-authored-by: Andy Luo --- ggml/src/ggml-cuda/common.cuh | 4 +++- ggml/src/ggml-cuda/mma.cuh | 17 +++++++++-------- ggml/src/ggml-cuda/mmq.cuh | 2 +- ggml/src/ggml-cuda/vendors/hip.h | 8 ++++++-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 64b91811c3..56a67f1edc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -67,6 +67,7 @@ #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 +#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 @@ -87,7 +88,8 @@ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4) +#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d1dadd3e4..c91dd2d9ad 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1025,7 +1025,8 @@ namespace ggml_cuda_mma { const floatx2_t& a_frag = reinterpret_cast(A.x[0]); const floatx2_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA1) +#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1) + // CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1 #pragma unroll for (int i = 0; i < 2; ++i) { acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); @@ -1187,7 +1188,7 @@ namespace ggml_cuda_mma { #elif defined(AMD_MFMA_AVAILABLE) using floatx4_t = __attribute__((ext_vector_type(4))) float; floatx4_t& acc_frag = reinterpret_cast(D.x[0]); -#if defined(CDNA3) || defined(CDNA2) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); @@ -1216,12 +1217,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], @@ -1230,7 +1231,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) @@ -1295,12 +1296,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], @@ -1309,7 +1310,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #else GGML_UNUSED_VARS(D, A, B); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 489d3616bb..1891114147 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3645,7 +3645,7 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA constexpr int ITER_K = get_iter_k(type); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d146e018d9..898fec31e3 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -189,6 +189,10 @@ #define GCN #endif // defined(GCN5) || defined(GCN4) +#if defined(__gfx950__) +#define CDNA4 +#endif // defined(__gfx950__) + #if defined(__gfx942__) #define CDNA3 #endif // defined(__gfx942__) @@ -201,9 +205,9 @@ #define CDNA1 #endif // defined(__gfx908__) -#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #define CDNA // For the entire family -#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4