mmq for rdna3
This commit is contained in:
parent
98846cb9ee
commit
62e4954d3f
|
|
@ -529,7 +529,7 @@ namespace ggml_cuda_mma {
|
|||
|
||||
static constexpr int ne = I * J / 32 * 2;
|
||||
|
||||
T x[ne] = {{0.0f, 0.0f}};
|
||||
T x[ne] = {0};
|
||||
|
||||
static constexpr __device__ bool supported() {
|
||||
if (I == 16 && J == 16) return true;
|
||||
|
|
|
|||
|
|
@ -797,8 +797,13 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
|
|||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
#if defined(RDNA3)
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
||||
#else
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
||||
#endif // defined(RDNA3)
|
||||
typedef tile<16, 8, int, input_layout> tile_A;
|
||||
typedef tile<16, 8, int, input_layout> tile_B;
|
||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
|
|
@ -966,8 +971,13 @@ template <int mmq_x, int mmq_y>
|
|||
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
#if defined(RDNA3)
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
||||
#else
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
||||
#endif // defined(RDNA3)
|
||||
typedef tile<16, 8, int, input_layout> tile_A;
|
||||
typedef tile<16, 8, int, input_layout> tile_B;
|
||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
|
|
@ -1179,8 +1189,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
#if defined(RDNA3)
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
||||
#else
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
||||
#endif // defined(RDNA3)
|
||||
typedef tile<16, 4, int, input_layout> tile_A;
|
||||
typedef tile<16, 4, int, input_layout> tile_B;
|
||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
|
|
@ -1501,9 +1516,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
#if defined(RDNA3)
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
||||
#else
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
||||
#endif // defined(RDNA3)
|
||||
typedef tile<16, 4, int, input_layout> tile_A;
|
||||
typedef tile<16, 4, int, input_layout> tile_B;
|
||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
|
|
@ -2316,8 +2335,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
|||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
#if defined(RDNA3)
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR_DUAL;
|
||||
#else
|
||||
constexpr data_layout input_layout = DATA_LAYOUT_I_MAJOR;
|
||||
#endif // defined(RDNA3)
|
||||
typedef tile<16, 4, int, input_layout> tile_A;
|
||||
typedef tile<16, 4, int, input_layout> tile_B;
|
||||
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
||||
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
||||
|
|
|
|||
Loading…
Reference in New Issue