From a31225cc0338b17e27b644e9ac53c57c5142c286 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Thu, 15 Jan 2026 16:07:15 +0800 Subject: [PATCH] pass cdna compile --- ggml/src/ggml-cuda/mma.cuh | 58 ++++++++++++++++++++++++++++++++++++-- ggml/src/ggml-cuda/mmf.cu | 4 +-- ggml/src/ggml-cuda/mmf.cuh | 18 +++++++++--- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 30c08b6ffb..3efda0021d 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -339,6 +339,32 @@ namespace ggml_cuda_mma { return -1; } } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } #else static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -391,7 +417,22 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; + static constexpr int ne = tile::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = tile::ne; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -1090,12 +1131,25 @@ namespace ggml_cuda_mma { NO_DEVICE_CODE; #endif // defined(RDNA4) #elif defined(AMD_MFMA_AVAILABLE) - using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; using floatx4_t = __attribute__((ext_vector_type(4))) float; floatx4_t& acc_frag = reinterpret_cast(D.x[0]); +#if defined(CDNA3) || defined(CDNA2) + using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16; + const bf16x2_t& a_frag = reinterpret_cast(A.x[i]); + const bf16x2_t& b_frag = reinterpret_cast(B.x[i]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) || defined(CDNA2) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 96fd3c5847..4e03d0db08 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -198,9 +198,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const case GGML_TYPE_F32: return ampere_mma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || amd_wmma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index d4a0bee322..e1bffa0b3a 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -30,12 +30,17 @@ static __global__ void mul_mat_f( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T, get_input_data_layout()> tile_A; typedef tile<16, 8, T, get_input_data_layout()> tile_B; typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { @@ -261,7 +266,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -276,12 +281,17 @@ static __global__ void mul_mat_f_ids( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T, get_input_data_layout()> tile_A; typedef tile<16, 8, T, get_input_data_layout()> tile_B; typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { @@ -532,7 +542,7 @@ static __global__ void mul_mat_f_ids( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } template