mma.cuh for RDNA4

zhang hui 2025-12-13 13:42:29 +08:00
parent 380b4c984e
commit 318cb5b80c
2 changed files with 134 additions and 20 deletions


@@ -76,9 +76,11 @@ namespace ggml_cuda_mma {
         // For the A/C matrices this means I major == row major, J major == column major.
         // For the B matrix this means I major == column major, J major == row major.
         // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell; matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4; int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL     = 40, // Matrix A&B for RDNA3.
     };
     // Implemented mma combinations are:
     // - (I_MAJOR, I_MAJOR) -> I_MAJOR
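A note on the taxonomy before the new layouts get used: "I major" and "J major" only say which logical index varies fastest in a tile's storage order. A minimal host-side sketch of the two linearizations for an I x J tile (illustrative only, not part of the diff):

    #include <cstdio>

    // For an I x J tile, I-major stores element (i, j) at offset i*J + j
    // (row major for the A/C matrices), while J-major stores it at j*I + i
    // (column major for A/C), matching the comments in the enum above.
    int main() {
        constexpr int I = 16, J = 16;
        const int i = 3, j = 5;
        std::printf("I-major offset of (%d,%d): %d\n", i, j, i * J + j); // 3*16 + 5 = 53
        std::printf("J-major offset of (%d,%d): %d\n", i, j, j * I + i); // 5*16 + 3 = 83
        return 0;
    }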
@@ -458,6 +460,46 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     };
+
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+#if defined(RDNA4)
+                return 8 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                NO_DEVICE_CODE;
+                return -1;
+#endif // defined(RDNA4)
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
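As a sanity check on the index math above: with ne = 16*16/32 = 8, the J-major C tile should hand each of the 256 elements of a 16x16 tile to exactly one (lane, l) pair. A host-side replay of get_i/get_j for both architecture branches (a sketch, not part of the diff; threadIdx.x is emulated by the lane loop):

    #include <cstdio>

    // Replays tile<16, 16, T, DATA_LAYOUT_J_MAJOR>::get_i/get_j from the hunk
    // above and counts how often each (i, j) cell is owned by some (lane, l).
    int main() {
        constexpr int ne = 16 * 16 / 32;
        for (int rdna4 = 0; rdna4 <= 1; rdna4++) {
            int count[16][16] = {};
            for (int lane = 0; lane < 32; lane++) {        // threadIdx.x
                for (int l = 0; l < ne; l++) {
                    const int i = rdna4 ? 8 * (lane / 16) + l  // RDNA4 branch
                                        : 2 * l + (lane / 16); // RDNA3 branch
                    const int j = lane % 16;
                    count[i][j]++;
                }
            }
            int bad = 0;
            for (int i = 0; i < 16; i++) {
                for (int j = 0; j < 16; j++) {
                    bad += count[i][j] != 1;
                }
            }
            std::printf("%s: %d cells not covered exactly once\n", rdna4 ? "RDNA4" : "RDNA3", bad);
        }
        return 0;
    }

Both mappings print 0: the two 16-lane halves of the wavefront split the rows (interleaved on RDNA3, contiguous halves on RDNA4) while threadIdx.x % 16 picks the column.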
     template <int I_, int J_>
     struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
         static constexpr int I = I_;
@@ -524,6 +566,63 @@ namespace ggml_cuda_mma {
         }
     };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+        static constexpr int ne = I * J / 32 * 2;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
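Unlike the mirrored layouts, DUAL doubles ne (16*8/32*2 = 8 for this tile), and get_i/get_j ignore the upper bit of the lane index, so lanes t and t + 16 own identical coordinates: every A/B value is held twice per 32-wide wavefront, once in each 16-lane half. My reading is that this matches what RDNA3's wave32 WMMA instructions expect of their input operands; the diff itself only labels the layout "Matrix A&B for RDNA3". A quick host-side check of that duplication property (a sketch, not part of the diff):

    #include <cstdio>

    // Replays tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_i/get_j:
    // get_i(l) = lane % 16, get_j(l) = l, so each (i, j) should be owned by
    // exactly two lanes (lane i and lane i + 16).
    int main() {
        constexpr int ne = 16 * 8 / 32 * 2;
        int count[16][8] = {};
        for (int lane = 0; lane < 32; lane++) {
            for (int l = 0; l < ne; l++) {
                count[lane % 16][l]++;
            }
        }
        int bad = 0;
        for (int i = 0; i < 16; i++) {
            for (int j = 0; j < 8; j++) {
                bad += count[i][j] != 2;
            }
        }
        std::printf("%d cells not held exactly twice\n", bad); // prints 0
        return 0;
    }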
+
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+        static constexpr int ne = I * J / 32 * 2;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_j(l);
+        }
+    };
+
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -660,9 +759,9 @@
 #endif // TURING_MMA_AVAILABLE
     }

-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +931,9 @@
 #endif // TURING_MMA_AVAILABLE
     }

+    template <data_layout ABLayout, data_layout DLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, DLayout> & D, const tile<16, 8, float, ABLayout> & A, const tile<8, 8, float, ABLayout> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -887,8 +987,9 @@
 #endif // AMPERE_MMA_AVAILABLE
     }

+    template <data_layout ABLayout, data_layout DLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, DLayout> & D, const tile<16, 8, half2, ABLayout> & A, const tile<16, 8, half2, ABLayout> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -940,8 +1041,9 @@
 #endif // TURING_MMA_AVAILABLE
     }

+    template <data_layout ABLayout, data_layout DLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, DLayout> & D, const tile<16, 8, nv_bfloat162, ABLayout> & A, const tile<16, 8, nv_bfloat162, ABLayout> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
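Taken together with the load_ldmatrix change above, the layout is now carried through the type system from load to mma. A sketch of one accumulation step as the mmf hunks below set it up for RDNA4, assuming this commit's mma.cuh is on the include path and that load_ldmatrix (which has a generic fallback outside the TURING path, not shown in this diff) is the intended loader on the AMD path as well:

    #include "mma.cuh" // assumption: this commit's header; a sketch, not a standalone program

    using namespace ggml_cuda_mma;

    // I-major A/B tiles, J-major C accumulator; ABLayout and DLayout are
    // deduced from the tile types by the templated overloads above.
    static __device__ void mma_step_sketch(
            const half2 * __restrict__ A_src, const half2 * __restrict__ B_src, const int stride) {
        tile<16,  8, half2, DATA_LAYOUT_I_MAJOR> A;
        tile<16,  8, half2, DATA_LAYOUT_I_MAJOR> B;
        tile<16, 16, float, DATA_LAYOUT_J_MAJOR> C;

        load_ldmatrix(A, A_src, stride);
        load_ldmatrix(B, B_src, stride);
        mma(C, A, B); // ABLayout = DATA_LAYOUT_I_MAJOR, DLayout = DATA_LAYOUT_J_MAJOR
    }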


@@ -32,11 +32,17 @@ static __global__ void mul_mat_f(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32: just a dummy mma layout, as WMMA doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+#if defined(RDNA3)
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+    constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
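The pattern worth noting here: the operand layout is a constexpr value chosen by preprocessor architecture guards plus a type predicate, so tile_A/tile_B/tile_C are fully resolved at compile time, and tf32 (T == float) keeps the dummy I-major layout everywhere because WMMA has no tf32 path. A self-contained mock of that selection (the is_rdna3 template parameter and half2_mock are invented for illustration; the real code uses #if defined(RDNA3) and CUDA's half2):

    #include <cstdio>
    #include <type_traits>

    // Enum values match the data_layout enum in mma.cuh.
    enum data_layout {
        DATA_LAYOUT_I_MAJOR      = 0,
        DATA_LAYOUT_J_MAJOR      = 10, // always used for tile_C on this path
        DATA_LAYOUT_I_MAJOR_DUAL = 40,
    };

    template <typename T, bool is_rdna3>
    constexpr data_layout ab_layout_for() {
        constexpr bool is_tf32 = std::is_same_v<T, float>; // tf32: dummy layout
        if (is_rdna3) {
            return is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
        }
        return DATA_LAYOUT_I_MAJOR; // RDNA4 and the other non-RDNA3 targets
    }

    struct half2_mock {};

    int main() {
        std::printf("RDNA4, half2: %d\n", (int) ab_layout_for<half2_mock, false>()); // 0
        std::printf("RDNA3, half2: %d\n", (int) ab_layout_for<half2_mock, true>());  // 40
        std::printf("RDNA3, tf32:  %d\n", (int) ab_layout_for<float, true>());       // 0
        return 0;
    }

The same constants then feed straight into the tile typedefs, as in the hunk above and its twin in mul_mat_f_ids below.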
@@ -272,11 +278,17 @@ static __global__ void mul_mat_f_ids(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32: just a dummy mma layout, as WMMA doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+#if defined(RDNA3)
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+    constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {