mma for rdna3

This commit is contained in:
zhang hui 2025-12-13 14:10:56 +08:00
parent 318cb5b80c
commit 074b93146e
1 changed files with 9 additions and 44 deletions

View File

@ -295,12 +295,7 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@ -319,14 +314,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
@ -384,42 +372,19 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#else
static constexpr int ne = I * J / WARP_SIZE;
@ -931,9 +896,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
template <data_layout ABLayout, data_layout DLayout>
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 8, float, DLayout> & D, const tile<16, 8, float, ABLayout> & A, const tile<8, 8, float, ABLayout> & B) {
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@ -987,9 +952,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout ABLayout, data_layout DLayout>
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, DLayout> & D, const tile<16, 8, half2, ABLayout> & A, const tile<16, 8, half2, ABLayout> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@ -1041,9 +1006,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
template <data_layout ABLayout, data_layout DLayout>
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, DLayout> & D, const tile<16, 8, nv_bfloat162, ABLayout> & A, const tile<16, 8, nv_bfloat162, ABLayout> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;