mmq for rdna4

commit 98846cb9ee (parent 074b93146e)
ggml/src/ggml-cuda/mma.cuh

@@ -87,6 +87,12 @@ namespace ggml_cuda_mma {
     // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
     // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR          ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
+               dl == DATA_LAYOUT_I_MAJOR_DUAL;
+    }
 
     template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};
 
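The new is_i_major() helper groups the three i-major layout tags so later code can branch on majorness alone instead of enumerating layouts. A compile-time sketch of how it behaves; the enum below is a stand-in for the real data_layout in mma.cuh, whose exact enumerator set and order are an assumption here:

    // Sketch only: minimal stand-in for mma.cuh's data_layout enum.
    enum data_layout {
        DATA_LAYOUT_I_MAJOR,
        DATA_LAYOUT_J_MAJOR,
        DATA_LAYOUT_I_MAJOR_MIRRORED,
        DATA_LAYOUT_J_MAJOR_MIRRORED,
        DATA_LAYOUT_I_MAJOR_DUAL,
    };

    constexpr bool is_i_major(const data_layout dl) {
        return dl == DATA_LAYOUT_I_MAJOR          ||
               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
               dl == DATA_LAYOUT_I_MAJOR_DUAL;
    }

    static_assert( is_i_major(DATA_LAYOUT_I_MAJOR_DUAL), "dual tiles are i-major");
    static_assert(!is_i_major(DATA_LAYOUT_J_MAJOR),      "j-major is not i-major");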
@@ -173,28 +179,19 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
         static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
         T x[ne] = {0};
 
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 16) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J ==  4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -203,7 +200,17 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
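For the RDNA4 16x16 tile a 32-lane wavefront holds ne = 16*16/32 = 8 elements per lane: get_i() puts a lane on row threadIdx.x % 16 and get_j() hands lanes 16..31 the upper eight columns. A host-side sanity sketch (not project code) showing that this mapping assigns each of the 256 tile elements to exactly one (lane, l) pair:

    #include <cassert>

    int main() {
        const int ne = 16 * 16 / 32;                 // 8 elements per lane
        int hits[16][16] = {};
        for (int lane = 0; lane < 32; ++lane) {      // one RDNA4 wavefront
            for (int l = 0; l < ne; ++l) {
                const int i = lane % 16;             // get_i: row
                const int j = ne * (lane / 16) + l;  // get_j: column
                ++hits[i][j];
            }
        }
        for (int i = 0; i < 16; ++i) {
            for (int j = 0; j < 16; ++j) {
                assert(hits[i][j] == 1);             // covered exactly once
            }
        }
        return 0;
    }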
@@ -440,28 +447,11 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
         }
     };
 
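The j-major specialization now delegates to the i-major tile with the coordinate roles swapped instead of duplicating the per-arch index math; conceptually it is a transpose of the same physical layout. A standalone sketch of the pattern (the lane id is passed explicitly here because threadIdx.x only exists in device code):

    template <int I, int J>
    struct tile_i_major_sketch {
        static int get_i(int /*l*/, int lane) { return lane % 16; }
        static int get_j(int l, int lane) { return (I * J / 32) * (lane / 16) + l; }
    };

    template <int I, int J>
    struct tile_j_major_sketch {
        // transpose: this tile's row index is the i-major tile's column index,
        // and vice versa, so only one physical mapping has to be maintained
        static int get_i(int l, int lane) { return tile_i_major_sketch<I, J>::get_j(l, lane); }
        static int get_j(int l, int lane) { return tile_i_major_sketch<I, J>::get_i(l, lane); }
    };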
@@ -531,23 +521,25 @@ namespace ggml_cuda_mma {
         }
     };
 
-    template <int I_, int J_>
-    struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL> {
+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
         static constexpr int I = I_;
         static constexpr int J = J_;
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
 
         static constexpr int ne = I * J / 32 * 2;
 
-        half2 x[ne] = {{0.0f, 0.0f}};
+        T x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J ==  8) return true;
+            if (I == 16 && J ==  4) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
@@ -556,7 +548,7 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return l;
             } else {
                 NO_DEVICE_CODE;
@@ -565,29 +557,6 @@ namespace ggml_cuda_mma {
         }
     };
 
-    template <int I_, int J_>
-    struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_DUAL> {
-        static constexpr int I = I_;
-        static constexpr int J = J_;
-        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
-
-        static constexpr int ne = I * J / 32 * 2;
-
-        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
-
-        static constexpr __device__ bool supported() {
-            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::supported();
-        }
-
-        static __device__ __forceinline__ int get_i(const int l) {
-            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_i(l);
-        }
-
-        static __device__ __forceinline__ int get_j(const int l) {
-            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_DUAL>::get_j(l);
-        }
-    };
-
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
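With the _DUAL specialization templated over T, the separate nv_bfloat162 copy becomes dead weight and is deleted; half2 and nv_bfloat162 now share one definition. A condensed, host-compilable sketch of the consolidated tile (the real struct also carries supported() and the device qualifiers):

    template <int I_, int J_, typename T>  // T: half2 or nv_bfloat162 in mma.cuh
    struct tile_dual_sketch {
        static constexpr int ne = I_ * J_ / 32 * 2; // two half-fragments per lane
        T x[ne] = {};

        static int get_i(int lane, int /*l*/) { return lane % 16; } // row
        static int get_j(int /*lane*/, int l) { return l; }         // column
    };

    static_assert(tile_dual_sketch<16, 8, float>::ne == 8, "16x8 dual: 8 per lane");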
@@ -638,50 +607,25 @@ namespace ggml_cuda_mma {
             xi[0] = xs[0];
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#if defined(RDNA4)
-            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-            NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v<T, int>) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
+        // All WMMA layouts have contiguous per-lane data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be 16-byte aligned when larger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
             } else {
-                NO_DEVICE_CODE;
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
             }
         } else {
-            NO_DEVICE_CODE;
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
         }
 #else
 #pragma unroll
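The replacement load path keys on the layout tag rather than the element type: for i-major layouts each lane's ne values are contiguous in memory, so the fragment is filled with one bulk copy, chunked when it exceeds the maximum aligned copy size; other layouts fall back to a per-element gather. A host-side model of the chunked branch, assuming ggml_cuda_get_max_cpy_bytes() is 16:

    #include <cstring>

    template <typename T, int ne>
    void load_i_major_sketch(T * dst, const T * src_row) {
        constexpr int max_cpy   = 16;             // assumed max aligned copy size
        constexpr int frag_size = ne * sizeof(T); // bytes held by this lane
        if constexpr (frag_size > max_cpy) {
            static_assert(frag_size % max_cpy == 0, "bad type size");
            constexpr int n_chunks  = frag_size / max_cpy;
            constexpr int per_chunk = ne / n_chunks;
            for (int i = 0; i < n_chunks; ++i) {  // one max_cpy-byte piece per step
                std::memcpy(dst + per_chunk * i, src_row + per_chunk * i, max_cpy);
            }
        } else {
            std::memcpy(dst, src_row, frag_size); // a single copy is enough
        }
    }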
@@ -1034,8 +978,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * acc = (int32x4_t *) D.x;
@@ -1189,8 +1134,9 @@ namespace ggml_cuda_mma {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }
 
-    static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
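Both integer mma() overloads gain data_layout template parameters so the accumulator and the input tiles can carry different tags while the compiler still deduces everything from the arguments. A hypothetical call site, assuming mma.cuh's definitions are in scope (the function name here is illustrative, not from the patch):

    typedef tile<16,  8, int, DATA_LAYOUT_I_MAJOR>  tile_A;
    typedef tile<16,  8, int, DATA_LAYOUT_I_MAJOR>  tile_B;
    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR>  tile_C;

    __device__ void mma_call_sketch(tile_C & C, const tile_A & A, const tile_B & B) {
        // dl_d = DATA_LAYOUT_J_MAJOR and dl_ab = DATA_LAYOUT_I_MAJOR are deduced
        // from the argument types; no explicit template arguments needed.
        mma(C, A, B);
    }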
ggml/src/ggml-cuda/mmq.cuh
@@ -797,9 +797,9 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR>  tile_A;
+    typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR>  tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
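The same substitution repeats in vec_dot_q8_1_q8_1_mma, vec_dot_q8_0_16_q8_1_mma, vec_dot_q2_K_q8_1_mma, and vec_dot_q6_K_q8_1_mma below: A and B stay i-major (each lane loads a contiguous run of packed quantized values), while C is tagged j-major to match the WMMA accumulator, where each lane owns a run of rows within a single column. An illustrative coordinate walk for one lane, assuming the RDNA4 mapping shown earlier:

    // For a j-major 16x16 C tile, get_i/get_j are the i-major mapping swapped:
    const int lane = 20;                    // example lane id in the wavefront
    for (int l = 0; l < 8; ++l) {
        const int i = 8 * (lane / 16) + l;  // rows 8..15 for lanes 16..31
        const int j = lane % 16;            // fixed column 4 for this lane
        // accumulator element (i, j) sits in slot l of this lane's fragment
    }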
@@ -966,9 +966,9 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR>  tile_A;
+    typedef tile<16, 8, int, DATA_LAYOUT_I_MAJOR>  tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1179,9 +1179,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles; no need to load 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_A;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1502,9 +1502,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles; no need to load 64x2 tiles
 
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_A;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1570,7 +1570,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 #elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
     typedef tile< 8, 4, int> tile_B;
     typedef tile<16, 8, int> tile_C;
@@ -2316,9 +2316,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles; no need to load 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_A;
+    typedef tile<16, 4, int, DATA_LAYOUT_I_MAJOR>  tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -3015,7 +3015,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
-    typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
     constexpr int rows_per_warp = granularity;
 #else
     typedef tile<16, 8, int> tile_C;
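mmq_write_back_mma picks up the same j-major tag for its C tile, so the store loop can stay layout-agnostic: get_i()/get_j() already return the transposed coordinates. A simplified model of the loop this typedef feeds (the real destination indexing and bounds checks are omitted):

    // tile_C::ne elements per lane; i/j come from the layout-aware accessors.
    for (int l = 0; l < tile_C::ne; ++l) {
        const int i = tile_C::get_i(l); // row in the output tile
        const int j = tile_C::get_j(l); // column in the output tile
        // store sum[l] at (i, j) in dst, subject to the usual bounds checks
    }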