Add missing tile for MFMA
This commit is contained in:
parent
6b8ed41f2b
commit
6acad9c759
|
|
@ -93,6 +93,14 @@ namespace ggml_cuda_mma {
|
||||||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
|
dl == DATA_LAYOUT_I_MAJOR_DUAL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile-time choice of the data layout that callers pass as the layout
// template argument of their A/B input tiles.
//
// Returns DATA_LAYOUT_I_MAJOR_DUAL when the RDNA3 macro is defined and
// DATA_LAYOUT_I_MAJOR otherwise. The function is constexpr, so its result can
// initialize a constexpr variable and be used as a template argument.
//
// Keeping the RDNA3 special case here means call sites do not need their own
// #if defined(RDNA3) block to pick the layout.
constexpr data_layout get_input_data_layout() {
|
||||||
|
#if defined(RDNA3)
|
||||||
|
return DATA_LAYOUT_I_MAJOR_DUAL;
|
||||||
|
#else
|
||||||
|
return DATA_LAYOUT_I_MAJOR;
|
||||||
|
#endif // defined(RDNA3)
|
||||||
|
}
|
||||||
|
|
||||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||||
struct tile {};
|
struct tile {};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,7 @@ static __global__ void mul_mat_f(
|
||||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||||
#if defined(RDNA3)
|
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // #if defined(RDNA3)
|
|
||||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
@ -281,11 +277,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
constexpr bool is_tf32 = std::is_same_v<T, float>;
|
||||||
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
constexpr int tile_B_I = is_tf32 ? 8 : 16;
|
||||||
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
constexpr int tile_C_J = is_tf32 ? 8 : 16;
|
||||||
#if defined(RDNA3)
|
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
|
||||||
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout ab_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // #if defined(RDNA3)
|
|
||||||
typedef tile<16, 8, T, ab_layout> tile_A;
|
typedef tile<16, 8, T, ab_layout> tile_A;
|
||||||
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
|
||||||
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
|
||||||
|
|
@ -797,11 +797,7 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
#if defined(RDNA3)
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // defined(RDNA3)
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_A;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 8, int, input_layout> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
@ -971,11 +967,7 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||||
#if defined(RDNA3)
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // defined(RDNA3)
|
|
||||||
typedef tile<16, 8, int, input_layout> tile_A;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 8, int, input_layout> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
@ -1140,10 +1132,11 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<64, 2, int> tile_load;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
typedef tile<64, 2, int, input_layout> tile_load;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -1189,11 +1182,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||||
#if defined(RDNA3)
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // defined(RDNA3)
|
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
@ -1450,10 +1439,11 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<64, 2, int> tile_load;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
typedef tile<64, 2, int, input_layout> tile_load;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -1516,11 +1506,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||||
#if defined(RDNA3)
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // defined(RDNA3)
|
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
@ -2284,10 +2270,11 @@ template <int mmq_x, int mmq_y>
|
||||||
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
typedef tile<16, 8, int> tile_A;
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
typedef tile<16, 8, int> tile_B;
|
typedef tile<16, 8, int, input_layout> tile_A;
|
||||||
typedef tile<16, 16, int> tile_C;
|
typedef tile<16, 8, int, input_layout> tile_B;
|
||||||
typedef tile<64, 2, int> tile_load;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
typedef tile<64, 2, int, input_layout> tile_load;
|
||||||
|
|
||||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||||
constexpr int rows_per_warp = granularity;
|
constexpr int rows_per_warp = granularity;
|
||||||
|
|
@ -2335,11 +2322,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||||
#if defined(RDNA3)
|
constexpr data_layout input_layout = get_input_data_layout();
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
#else
|
|
||||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
|
||||||
#endif // defined(RDNA3)
|
|
||||||
typedef tile<16, 4, int, input_layout> tile_A;
|
typedef tile<16, 4, int, input_layout> tile_A;
|
||||||
typedef tile<16, 4, int, input_layout> tile_B;
|
typedef tile<16, 4, int, input_layout> tile_B;
|
||||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue