diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index dcfa40f4d5..16f28f6ab9 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -76,9 +76,11 @@ namespace ggml_cuda_mma {
         //   For the A/C matrices this means I major == row major, J major == column major.
         //   For the B matrix this means I major == column major, J major == row major.
         //   MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL     = 40, // Matrix A&B for RDNA3.
     };
     // Implemented mma combinations are:
     //   - (I_MAJOR, I_MAJOR) -> I_MAJOR
@@ -458,6 +460,46 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     };
 
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = I * J / 32;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 16) {
+#if defined(RDNA4)
+                return 8 * (threadIdx.x / 16) + l;
+#elif defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                NO_DEVICE_CODE;
+                return -1;
+#endif // defined(RDNA4)
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
     template <int I_, int J_>
     struct tile<I_, J_, half2> {
         static constexpr int I = I_;
@@ -524,6 +566,63 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+
+        static constexpr int ne = I * J / 32 * 2;
+
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+
+        static constexpr int ne = I * J / 32 * 2;
+
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I, J, half2, DATA_LAYOUT_I_MAJOR_DUAL>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I, J, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I, J, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_j(l);
+        }
+    };
+
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@@ -660,9 +759,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +931,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout DLayout, data_layout ABLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, DLayout> & D, const tile<16, 8, float, ABLayout> & A, const tile<8, 8, float, ABLayout> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -886,9 +986,10 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // AMPERE_MMA_AVAILABLE
     }
-
+
+    template <data_layout DLayout, data_layout ABLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, DLayout> & D, const tile<16, 8, half2, ABLayout> & A, const tile<16, 8, half2, ABLayout> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -939,9 +1040,10 @@ namespace ggml_cuda_mma {
         NO_DEVICE_CODE;
 #endif // TURING_MMA_AVAILABLE
     }
-
+
+    template <data_layout DLayout, data_layout ABLayout>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, DLayout> & D, const tile<16, 8, nv_bfloat162, ABLayout> & A, const tile<16, 8, nv_bfloat162, ABLayout> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index e1c695c5c0..e1b9c6a6b7 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -32,11 +32,17 @@ static __global__ void mul_mat_f(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+#if defined(RDNA3)
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+    constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
+#endif // #if defined(RDNA3)
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
@@ -272,11 +278,17 @@ static __global__ void mul_mat_f_ids(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+#if defined(RDNA3)
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+    constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
+#endif // #if defined(RDNA3)
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
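
A quick way for reviewers to sanity-check the new 16x16 `DATA_LAYOUT_J_MAJOR` index math: the host-side sketch below (not part of the patch; the function names `rdna4_get_i`/`rdna3_get_i` are illustrative only) replays the `get_i()`/`get_j()` mapping for all 32 lanes of a wave and verifies that both the RDNA4 and RDNA3 branches cover each cell of the C tile exactly once. Reading the math directly: on RDNA4 the upper half-wave holds the consecutive row block 8..15, while on RDNA3 the lower half holds the even rows and the upper half the odd rows.

```cpp
// Host-side sketch: enumerate the (i, j) coordinates produced by the
// 16x16 DATA_LAYOUT_J_MAJOR tile for each lane of a 32-wide wave,
// mirroring the get_i()/get_j() math in the diff above.
#include <cstdio>

// 16x16 tile over 32 threads => ne = 16*16/32 = 8 values per thread.
constexpr int NE = 16 * 16 / 32;

static int rdna4_get_i(int tid, int l) { return 8 * (tid / 16) + l; } // consecutive row blocks
static int rdna3_get_i(int tid, int l) { return 2 * l + (tid / 16); } // interleaved even/odd rows
static int get_j(int tid, int /*l*/)   { return tid % 16; }           // same on both

int main() {
    // Each (i, j) cell of the 16x16 tile must be covered exactly once.
    int cover_rdna4[16][16] = {};
    int cover_rdna3[16][16] = {};
    for (int tid = 0; tid < 32; ++tid) {
        for (int l = 0; l < NE; ++l) {
            cover_rdna4[rdna4_get_i(tid, l)][get_j(tid, l)]++;
            cover_rdna3[rdna3_get_i(tid, l)][get_j(tid, l)]++;
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            if (cover_rdna4[i][j] != 1 || cover_rdna3[i][j] != 1) {
                printf("hole/overlap at i=%d j=%d\n", i, j);
                return 1;
            }
        }
    }
    printf("both layouts cover the 16x16 tile exactly once\n");
    return 0;
}
```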
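The doubled element count in `DATA_LAYOUT_I_MAJOR_DUAL` (`ne = I * J / 32 * 2`) follows from `get_i() = threadIdx.x % 16` and `get_j() = l` ignoring which 16-lane half of the wave a thread is in: lanes `t` and `t + 16` hold identical data. That matches RDNA3's wave32 WMMA, which expects the A/B operands replicated across both halves of the wave, and is why the `ab_layout` selection in mmf.cuh is RDNA3-only. A minimal host-side check, again a sketch rather than part of the change:

```cpp
// Host-side sketch: confirm that the 16x8 half2 I_MAJOR_DUAL mapping
// stores every tile element exactly twice per 32-wide wave (once per
// 16-lane half), matching RDNA3 WMMA's operand replication.
#include <cstdio>

int main() {
    int held[16][8] = {}; // how many lanes hold element (i, j)
    for (int tid = 0; tid < 32; ++tid) {
        for (int l = 0; l < 8; ++l) {  // ne = 16*8/32*2 = 8 per thread
            const int i = tid % 16;    // get_i(l)
            const int j = l;           // get_j(l)
            held[i][j]++;
        }
    }
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 8; ++j) {
            if (held[i][j] != 2) {
                printf("expected duplication at i=%d j=%d\n", i, j);
                return 1;
            }
        }
    }
    printf("every element is held exactly twice per wave (once per half)\n");
    return 0;
}
```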