From 62e4954d3f35d9f8e86139595ac0d1f0c912e7e8 Mon Sep 17 00:00:00 2001 From: zhang hui Date: Sat, 13 Dec 2025 16:52:25 +0800 Subject: [PATCH] mmq for rdna3 --- ggml/src/ggml-cuda/mma.cuh | 2 +- ggml/src/ggml-cuda/mmq.cuh | 46 +++++++++++++++++++++++++++++--------- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 8a53f19341..e56da4329b 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -529,7 +529,7 @@ namespace ggml_cuda_mma { static constexpr int ne = I * J / 32 * 2; - T x[ne] = {{0.0f, 0.0f}}; + T x[ne] = {0}; static constexpr __device__ bool supported() { if (I == 16 && J == 16) return true; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index e748f24c3a..d1c75a22e1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -797,8 +797,13 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A; - typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B; +#if defined(RDNA3) + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; +#else + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -966,8 +971,13 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A; - typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B; +#if defined(RDNA3) + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; +#else + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -1179,8 +1189,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; +#if defined(RDNA3) + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; +#else + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -1501,9 +1516,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; +#if defined(RDNA3) + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; +#else + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -2316,8 +2335,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; - typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; +#if defined(RDNA3) + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL; +#else + constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x);