From 6acad9c7599b237232da78f4097f78f762d3366e Mon Sep 17 00:00:00 2001 From: zhang hui Date: Sat, 13 Dec 2025 20:26:39 +0800 Subject: [PATCH] add missing tile of mfma --- ggml/src/ggml-cuda/mma.cuh | 8 ++++++ ggml/src/ggml-cuda/mmf.cuh | 12 ++------ ggml/src/ggml-cuda/mmq.cuh | 57 +++++++++++++------------------------- 3 files changed, 30 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 70bc60d320..74e58c322a 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -93,6 +93,14 @@ namespace ggml_cuda_mma { dl == DATA_LAYOUT_I_MAJOR_DUAL; } + constexpr data_layout get_input_data_layout() { +#if defined(RDNA3) + return DATA_LAYOUT_I_MAJOR_DUAL; +#else + return DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + } + template struct tile {}; diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e1b9c6a6b7..e36730948f 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -35,11 +35,7 @@ static __global__ void mul_mat_f( constexpr bool is_tf32 = std::is_same_v; constexpr int tile_B_I = is_tf32 ? 8 : 16; constexpr int tile_C_J = is_tf32 ? 8 : 16; -#if defined(RDNA3) - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR; -#endif // #if defined(RDNA3) + constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); typedef tile<16, 8, T, ab_layout> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; @@ -281,11 +277,7 @@ static __global__ void mul_mat_f_ids( constexpr bool is_tf32 = std::is_same_v; constexpr int tile_B_I = is_tf32 ? 8 : 16; constexpr int tile_C_J = is_tf32 ? 8 : 16; -#if defined(RDNA3) - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR; -#endif // #if defined(RDNA3) + constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); typedef tile<16, 8, T, ab_layout> tile_A; typedef tile tile_B; typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a91d95df15..fa8a72c9c1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -797,11 +797,7 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) + constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 8, int, input_layout> tile_A; typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; @@ -971,11 +967,7 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) + constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 8, int, input_layout> tile_A; typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; @@ -1140,10 +1132,11 @@ template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1189,11 +1182,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles -#if defined(RDNA3) - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) + constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; @@ -1450,10 +1439,11 @@ template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1516,11 +1506,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles -#if defined(RDNA3) - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) + constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; @@ -2284,10 +2270,11 @@ template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -2335,11 +2322,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles -#if defined(RDNA3) - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; -#else - constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) + constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;