HIP: RDNA4 tensor core support for MMF (#17077)

* mmf for rdna4

* align the padding for rdna4

* forbid mul_mat_f for rdna4

* fix as comment

* remove device kernels

* add constexpr for early return

* update based on review comment

* change based on the review comment

* fix compile error

* keep code consistency

---------

Co-authored-by: zhang hui <you@example.com>
Author: yulo
Date:   2025-11-22 07:03:24 +08:00 (committed by GitHub)
Parent: 8e9ddba610
Commit: 028f93ef98
5 changed files with 180 additions and 23 deletions

File 1 of 5: common.cuh (WMMA availability macro and helper)

@@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
+#if defined(GGML_USE_HIP) && defined(RDNA4)
+#define AMD_WMMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(RDNA4)
+
 // The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 #define VOLTA_MMA_AVAILABLE
@@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
 }
 
+static bool amd_wmma_available(const int cc) {
+    return GGML_CUDA_CC_IS_RDNA4(cc);
+}
+
 static bool volta_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
 }
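The two additions split the gate in half: AMD_WMMA_AVAILABLE is a compile-time switch that only exists in device code built for RDNA4 under HIP, while amd_wmma_available(cc) is the host-side check used when picking a kernel for a given device. A minimal sketch of how the pair is typically used together (illustrative names, assumes common.cuh is included):

__global__ void rdna4_wmma_kernel_sketch() {
#if defined(AMD_WMMA_AVAILABLE)
    // Compiled only into RDNA4 (gfx12) code objects; WMMA tile math goes here.
#else
    NO_DEVICE_CODE; // reaching this on any other arch indicates a dispatch bug
#endif // defined(AMD_WMMA_AVAILABLE)
}

static bool should_take_wmma_path_sketch(const int cc) {
    return amd_wmma_available(cc); // host side: true only for RDNA4 devices
}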

File 2 of 5: convert.cuh (ggml_cuda_cast float2 conversions)

@@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
         return __float2bfloat16(float(x));
     } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
         return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
+        return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
+        // bypass compile error on cuda 12.0.1
+#ifdef GGML_USE_HIP
+        return __float22bfloat162_rn(x);
+#else
+        return {x.x, x.y};
+#endif // GGML_USE_HIP
     } else if constexpr(std::is_same_v<dst_t, int32_t>) {
         return int32_t(x);
     } else {
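These branches let the mmf kernels further down convert a loaded float2 directly into the packed 16-bit element type of their shared-memory tiles via ggml_cuda_cast<T>(tmp). A standalone sketch of the same conversion pattern (helper names are illustrative; assumes the CUDA fp16/bf16 headers, with HIP builds providing the equivalent intrinsics):

#include <cuda_fp16.h>  // half2, __float22half2_rn
#include <cuda_bf16.h>  // nv_bfloat162, __float22bfloat162_rn

__device__ __forceinline__ half2 pack_f32x2_to_f16x2_sketch(const float2 v) {
    return __float22half2_rn(v); // round-to-nearest-even on both lanes
}

__device__ __forceinline__ nv_bfloat162 pack_f32x2_to_bf16x2_sketch(const float2 v) {
#ifdef GGML_USE_HIP
    return __float22bfloat162_rn(v);
#else
    // CUDA 12.0.1 rejects __float22bfloat162_rn here, so rely on the
    // per-component float -> nv_bfloat16 conversions instead.
    return {v.x, v.y};
#endif // GGML_USE_HIP
}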

File 3 of 5: mma.cuh (tile layouts and WMMA mma overloads)

@@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
         static constexpr int J = J_;
 
 #if defined(GGML_USE_HIP)
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return 8 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / 64;
         T x[ne] = {0};
@@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#endif // defined(RDNA4)
 #elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
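In the new RDNA4 branch a 16x16 accumulator tile is spread over one 32-lane wave, ne = 16*16/32 = 8 elements per lane, with lane (threadIdx.x) and slot l mapped to row 8*(lane/16) + l and column lane%16. A host-side replay of that mapping (plain C++, not part of the patch) showing it covers the tile exactly once:

#include <cassert>
#include <cstdio>

int main() {
    int hits[16][16] = {};
    for (int lane = 0; lane < 32; ++lane) {     // one RDNA4 wave32
        for (int l = 0; l < 8; ++l) {           // ne = 8 elements per lane
            const int i = 8 * (lane / 16) + l;  // get_i(l)
            const int j = lane % 16;            // get_j(l)
            ++hits[i][j];
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            assert(hits[i][j] == 1); // every element owned by exactly one (lane, l)
        }
    }
    printf("16x16 tile covered exactly once by 32 lanes x 8 slots\n");
    return 0;
}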
@@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
         half2 x[ne] = {{0.0f, 0.0f}};
@@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
     struct tile<I_, J_, nv_bfloat162> {
         static constexpr int I = I_;
         static constexpr int J = J_;
+#if defined(AMD_WMMA_AVAILABLE)
+        static constexpr int ne = I * J / 32;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#else
         static constexpr int ne = I * J / WARP_SIZE;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
@@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
+#endif // defined(AMD_WMMA_AVAILABLE)
     };
 
     template <int I, int J>
@@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
             const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
             xi[0] = xs[0];
         }
+#elif defined(AMD_WMMA_AVAILABLE)
+        ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
 #else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
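The single ggml_cuda_memcpy_1 is enough on the WMMA path because get_i does not depend on l and get_j(l) = get_j(0) + l, so each lane's fragment is one contiguous run; for tile<16, 8, half2> that is ne = 4 half2 values, i.e. sizeof(t.x) = 16 bytes per lane. A small host-side check of that contiguity (plain C++, the stride value is an arbitrary example):

#include <cassert>
#include <cstdio>

int main() {
    const int stride = 40; // row stride in half2 elements, arbitrary example
    for (int lane = 0; lane < 32; ++lane) {
        const int i = lane % 16;          // get_i(l), identical for every l
        int prev = -1;
        for (int l = 0; l < 4; ++l) {     // ne = 16*8/32 = 4 half2 per lane
            const int j   = 4 * (lane / 16) + l;   // get_j(l)
            const int off = i * stride + j;        // element index of x[l]
            assert(prev == -1 || off == prev + 1); // back to back in memory
            prev = off;
        }
    }
    printf("each lane's fragment is 4 consecutive half2 values (16 bytes)\n");
    return 0;
}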
@@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
                       : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
                       : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_WMMA_AVAILABLE)
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
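Both new overloads map one warp-wide 16x16x16 multiply-accumulate onto a single gfx12 WMMA builtin, each lane supplying an 8-element vector per operand and holding 8 floats of the accumulator. A condensed HIP sketch of the f16 variant (standalone helper with illustrative names, RDNA4/wave32 only):

using halfx8_t  = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;

// acc += a * b for one 16x16x16 tile per wave32; the bf16 overload above is
// identical except for __bf16 vectors and the ..._bf16_w32_gfx12 builtin.
__device__ __forceinline__ void wmma_f16_16x16x16_sketch(
        floatx8_t & acc, const halfx8_t & a, const halfx8_t & b) {
    acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a, b, acc);
}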

File 4 of 5: ggml_cuda_should_use_mmf type and shape gating

@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16) {
+        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
             return false;
         }
     }
@@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
         case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
        case GGML_TYPE_BF16:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc);
        default:
            return false;
    }
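With this change the per-type gate reads: F32 stays on NVIDIA Ampere MMA, while F16 and BF16 additionally accept RDNA4 through amd_wmma_available(cc); the first hunk also returns false for RDNA4 in the branch that previously only capped src1_ncols at 16. A condensed sketch of the resulting type switch (omitting the shape and id checks the full function performs first):

static bool mmf_type_ok_sketch(const ggml_type type, const int cc) {
    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc);
        default:
            return false;
    }
}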

File 5 of 5: mul_mat_f / mul_mat_f_ids kernels and launcher

@@ -2,6 +2,7 @@
 
 #include "mma.cuh"
 #include "common.cuh"
+#include "convert.cuh"
 
 using namespace ggml_cuda_mma;
@@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;
 
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
     constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
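The tf32 special case leans on compile-time pruning: for T = float the chosen tile shapes report supported() == false, so the if constexpr (!supported) early return removes the entire WMMA body rather than instantiating an mma overload that does not exist. The same pattern in isolation (simplified, hypothetical tile type):

#include <type_traits>

template <int I, int J>
struct acc_tile_sketch {
    // mirrors the RDNA4 accumulator tile: only the 16x16 layout is real
    static constexpr bool supported() { return I == 16 && J == 16; }
};

template <typename T>
__global__ void wmma_kernel_sketch() {
    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16; // dummy shape for tf32
    constexpr bool supported = acc_tile_sketch<16, tile_C_J>::supported();
    if constexpr (!supported) {
        return; // for T = float the body below is never instantiated
    }
    // ... WMMA tile math for half2 / nv_bfloat162 would go here ...
}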
@@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
             if constexpr (!has_ids) {
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             } else {
                 const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                 float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             }
         }
     } else {
@@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 //This kernel is for larger batch sizes of mul_mat_id
@@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
+    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
+    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
+    typedef tile<16, 8, T> tile_A;
+    typedef tile<tile_B_I, 8, T> tile_B;
+    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool a_supported = tile_A::supported();
+    constexpr bool b_supported = tile_B::supported();
+    constexpr bool c_supported = tile_C::supported();
+    constexpr bool supported = a_supported && b_supported && c_supported;
+#else
     constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
     constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
+    constexpr bool supported = I_16_supported || I_32_supported;
 
-    if (!I_16_supported && !I_32_supported) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
+    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
     typedef tile<I_preferred, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<I_preferred, 8, float> tile_C;
+#endif // defined(AMD_WMMA_AVAILABLE)
+
+    if constexpr (!supported) {
+        NO_DEVICE_CODE;
+        return;
+    }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int tile_k_padded = warp_size + 4;
@@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
 #pragma unroll
             for (int j0 = 0; j0 < tile_B::I; ++j0) {
                 const float2 tmp = vals_buf[curr_buf][j0];
-                tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
             }
 
             if (itB + 1 < ntB) {
@@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 }
 
 template<typename T, int cols_per_block, int nwarps>
@@ -554,7 +585,8 @@ void mul_mat_f_cuda(
         cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T> tile_A_16;
     typedef tile<32, 8, T> tile_A_32;
-    typedef tile< 8, 8, T> tile_B;
+    typedef tile<16, 8, T> tile_B_16;
+    typedef tile< 8, 8, T> tile_B_8;
 
     GGML_ASSERT(ncols_x % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
@@ -581,7 +613,8 @@ void mul_mat_f_cuda(
     constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
     const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
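Since the WMMA combine step pads the column count to tile_B_16::I = 16 instead of 8, the combine-stage shared-memory requirement can grow for small column counts. A host-side recomputation of the two byte counts with placeholder parameters (warp count, MMF_ROWS_PER_BLOCK and the tile height are example values, not the library's actual constants):

#include <algorithm>
#include <cstdio>

static int pad_to(const int x, const int m) { return (x + m - 1) / m * m; } // stand-in for GGML_PAD

int main() {
    const int warp_size      = 32; // placeholder
    const int nwarps_best    = 4;  // placeholder
    const int rows_per_block = 32; // placeholder for MMF_ROWS_PER_BLOCK
    const int cols_per_block = 4;  // small batch, where the padding change matters
    const int tile_A_I       = 16; // iteration tile height on the non-Volta path

    for (int w = 0; w < 2; ++w) {
        const bool amd_wmma = w != 0;
        const int  cols_pad = amd_wmma ? 16 : 8; // tile_B_16::I vs tile_B_8::I
        const int  nbytes_shared_iter    = nwarps_best * tile_A_I * (warp_size + 4) * 4;
        const int  nbytes_shared_combine = pad_to(cols_per_block, cols_pad) * (nwarps_best*rows_per_block + 4) * 4;
        const int  nbytes_shared         = std::max(nbytes_shared_iter, nbytes_shared_combine);
        printf("wmma=%d: iter=%d combine=%d -> shared=%d bytes\n",
               (int) amd_wmma, nbytes_shared_iter, nbytes_shared_combine, nbytes_shared);
    }
    return 0;
}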