Add missing data layout template arguments to the MFMA tile typedefs

This commit is contained in:
zhang hui 2025-12-13 20:26:39 +08:00
parent 6b8ed41f2b
commit 6acad9c759
3 changed files with 30 additions and 47 deletions

View File

@ -93,6 +93,14 @@ namespace ggml_cuda_mma {
dl == DATA_LAYOUT_I_MAJOR_DUAL; dl == DATA_LAYOUT_I_MAJOR_DUAL;
} }
// Compile-time selector for the data layout of matrix-multiply input tiles.
// Yields DATA_LAYOUT_I_MAJOR_DUAL when the RDNA3 macro is defined and
// DATA_LAYOUT_I_MAJOR for every other build target, so call sites do not need
// to repeat the preprocessor check themselves.
constexpr data_layout get_input_data_layout() {
#if !defined(RDNA3)
    return DATA_LAYOUT_I_MAJOR;
#else
    return DATA_LAYOUT_I_MAJOR_DUAL;
#endif // !defined(RDNA3)
}
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR> template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
struct tile {}; struct tile {};

View File

@ -35,11 +35,7 @@ static __global__ void mul_mat_f(
constexpr bool is_tf32 = std::is_same_v<T, float>; constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16; constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16; constexpr int tile_C_J = is_tf32 ? 8 : 16;
#if defined(RDNA3) constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
#endif // #if defined(RDNA3)
typedef tile<16, 8, T, ab_layout> tile_A; typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B; typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
@ -281,11 +277,7 @@ static __global__ void mul_mat_f_ids(
constexpr bool is_tf32 = std::is_same_v<T, float>; constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16; constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16; constexpr int tile_C_J = is_tf32 ? 8 : 16;
#if defined(RDNA3) constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
#endif // #if defined(RDNA3)
typedef tile<16, 8, T, ab_layout> tile_A; typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B; typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;

View File

@ -797,11 +797,7 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3) constexpr data_layout input_layout = get_input_data_layout();
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 8, int, input_layout> tile_A; typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
@ -971,11 +967,7 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3) constexpr data_layout input_layout = get_input_data_layout();
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 8, int, input_layout> tile_A; typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B; typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
@ -1140,10 +1132,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A; constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int> tile_B; typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 16, int> tile_C; typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<64, 2, int> tile_load; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity; constexpr int rows_per_warp = granularity;
@ -1189,11 +1182,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#if defined(RDNA3) constexpr data_layout input_layout = get_input_data_layout();
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
@ -1450,10 +1439,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A; constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int> tile_B; typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 16, int> tile_C; typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<64, 2, int> tile_load; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity; constexpr int rows_per_warp = granularity;
@ -1516,11 +1506,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#if defined(RDNA3) constexpr data_layout input_layout = get_input_data_layout();
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
@ -2284,10 +2270,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A; constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int> tile_B; typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 16, int> tile_C; typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<64, 2, int> tile_load; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity; constexpr int rows_per_warp = granularity;
@ -2335,11 +2322,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
} }
} }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#if defined(RDNA3) constexpr data_layout input_layout = get_input_data_layout();
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
#else
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B; typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;