From 074b93146e410aac4de51c197a75814e62671154 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Sat, 13 Dec 2025 14:10:56 +0800 Subject: [PATCH] mma for rdna3 --- ggml/src/ggml-cuda/mma.cuh | 53 +++++++------------------------------- 1 file changed, 9 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 16f28f6ab9..c4016a49eb 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -295,12 +295,7 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - // RDNA3 has duplicated data as input. - static constexpr int ne = I * J / 32 * 2; -#else static constexpr int ne = I * J / 32; -#endif // defined(RDNA3) half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -319,14 +314,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { -#if defined(RDNA4) return 4 * (threadIdx.x / 16) + l; -#elif defined(RDNA3) - return l; -#else - NO_DEVICE_CODE; - return -1; -#endif // defined(RDNA4) } else { NO_DEVICE_CODE; return -1; @@ -384,42 +372,19 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - // RDNA3 has duplicated data as input. - static constexpr int ne = I * J / 32 * 2; -#else static constexpr int ne = I * J / 32; -#endif // defined(RDNA3) nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; - return false; + return tile::supported(); } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 16 && J == 8) { - return threadIdx.x % 16; - } else { - NO_DEVICE_CODE; - return -1; - } + return tile::get_i(l); } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 16 && J == 8) { -#if defined(RDNA4) - return 4 * (threadIdx.x / 16) + l; -#elif defined(RDNA3) - return l; -#else - NO_DEVICE_CODE; - return -1; -#endif // defined(RDNA4) - } else { - NO_DEVICE_CODE; - return -1; - } + return tile::get_j(l); } #else static constexpr int ne = I * J / WARP_SIZE; @@ -931,9 +896,9 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void mma( - tile<16, 8, float, DLayout> & D, const tile<16, 8, float, ABLayout> & A, const tile<8, 8, float, ABLayout> & B) { + tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { #ifdef AMPERE_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -987,9 +952,9 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void mma( - tile<16, 16, float, DLayout> & D, const tile<16, 8, half2, ABLayout> & A, const tile<16, 8, half2, ABLayout> & B) { + tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) { #ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -1041,9 +1006,9 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void mma( - tile<16, 16, float, DLayout> & D, const tile<16, 8, nv_bfloat162, ABLayout> & A, const tile<16, 8, nv_bfloat162, ABLayout> & B) { + tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) { #if defined(AMD_WMMA_AVAILABLE) #if defined(RDNA4) using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;