mmq for rdna3

This commit is contained in:
zhang hui 2025-12-13 16:52:25 +08:00
parent 98846cb9ee
commit 62e4954d3f
2 changed files with 36 additions and 12 deletions

View File

@@ -529,7 +529,7 @@ namespace ggml_cuda_mma {
static constexpr int ne = I * J / 32 * 2; static constexpr int ne = I * J / 32 * 2;
T x[ne] = {{0.0f, 0.0f}}; T x[ne] = {0};
static constexpr __device__ bool supported() { static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true; if (I == 16 && J == 16) return true;

View File

@@ -797,8 +797,13 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A; #if defined(RDNA3)
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B; constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -966,8 +971,13 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A; #if defined(RDNA3)
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B; constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -1179,8 +1189,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; #if defined(RDNA3)
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -1501,9 +1516,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#if defined(RDNA3)
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; #else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -2316,8 +2335,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A; #if defined(RDNA3)
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B; constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);