HIP: WMMA-MMQ kernels for RDNA 4 (#17156)
* first commit naive test to enable mmq for RDNA4
* adding appropriate WMMA instructions
* git rebase on top of master: fixing the correctness of the mat mul operations, updating layout mappings for RDNA4
* clean up merge conflicts
* add comments and code clean up
* PR clean up, addressed comments
* enable MMQ fallback on RDNA4
* addressed comments: add guards in load generic, separate wmma branch for use_mmq function
* Revert build-xcframework.sh
* Formating: remove trailing whitespace
* revert CMake files
* clean up after rebase: remove duplicated change, revert cmake files
* clean up after rebase: revert changes from build-xcframework.sh
* clean up: remove extra space line in mma.cuh
* Revert "clean up: remove extra space line in mma.cuh"
This reverts commit b39ed57c45.
This commit is contained in:
parent
b61de2b2df
commit
0543f928a3
|
|
@ -73,34 +73,7 @@ namespace ggml_cuda_mma {
|
||||||
static constexpr int I = I_;
|
static constexpr int I = I_;
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
#if defined(RDNA4)
|
|
||||||
static constexpr int ne = I * J / 32;
|
|
||||||
T x[ne] = {0};
|
|
||||||
|
|
||||||
static constexpr __device__ bool supported() {
|
|
||||||
if (I == 16 && J == 16) return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_i(const int l) {
|
|
||||||
if constexpr (I == 16 && J == 16) {
|
|
||||||
return 8 * (threadIdx.x / 16) + l;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
|
||||||
if constexpr (I == 16 && J == 16) {
|
|
||||||
return threadIdx.x % 16;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
static constexpr int ne = I * J / 64;
|
static constexpr int ne = I * J / 64;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
|
@ -146,7 +119,6 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // defined(RDNA4)
|
|
||||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
static constexpr int ne = I * J / 32;
|
static constexpr int ne = I * J / 32;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
@ -177,6 +149,34 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#if defined(RDNA4)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 16) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return 8 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
static constexpr int ne = I * J / 32;
|
static constexpr int ne = I * J / 32;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
@ -437,7 +437,20 @@ namespace ggml_cuda_mma {
|
||||||
xi[0] = xs[0];
|
xi[0] = xs[0];
|
||||||
}
|
}
|
||||||
#elif defined(AMD_WMMA_AVAILABLE)
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
if constexpr (I == 16 && J == 4) {
|
||||||
|
int64_t * xi = (int64_t *) t.x;
|
||||||
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||||
|
xi[0] = xs[0];
|
||||||
|
}else if constexpr (I == 16 && J == 8) {
|
||||||
|
int64_t * xi = (int64_t *) t.x;
|
||||||
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
|
||||||
|
xi[0] = xs[0];
|
||||||
|
|
||||||
|
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
|
||||||
|
xi[1] = xs1[0];
|
||||||
|
}else{
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
|
|
@ -772,6 +785,36 @@ namespace ggml_cuda_mma {
|
||||||
acc[0],
|
acc[0],
|
||||||
0, 0, 0);
|
0, 0, 0);
|
||||||
#endif // defined(CDNA3)
|
#endif // defined(CDNA3)
|
||||||
|
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||||
|
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||||
|
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||||
|
|
||||||
|
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||||
|
int32x8_t * acc = (int32x8_t *) D.x;
|
||||||
|
|
||||||
|
#if defined(RDNA4)
|
||||||
|
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||||
|
true,
|
||||||
|
a_vec[0],
|
||||||
|
true,
|
||||||
|
b_vec[0],
|
||||||
|
acc[0],
|
||||||
|
true
|
||||||
|
);
|
||||||
|
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||||
|
true,
|
||||||
|
a_vec[1],
|
||||||
|
true,
|
||||||
|
b_vec[1],
|
||||||
|
acc[0],
|
||||||
|
true
|
||||||
|
);
|
||||||
|
#endif // defined(RDNA4)
|
||||||
|
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
@ -798,6 +841,7 @@ namespace ggml_cuda_mma {
|
||||||
acc[0],
|
acc[0],
|
||||||
0, 0, 0);
|
0, 0, 0);
|
||||||
#endif // defined(CDNA3)
|
#endif // defined(CDNA3)
|
||||||
|
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
@ -842,4 +886,31 @@ namespace ggml_cuda_mma {
|
||||||
mma(D16[1], A16[1], B);
|
mma(D16[1], A16[1], B);
|
||||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void mma(
|
||||||
|
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||||
|
int32x2_t * a_vec = (int32x2_t *) A.x;
|
||||||
|
int32x2_t * b_vec = (int32x2_t *) B.x;
|
||||||
|
|
||||||
|
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
|
||||||
|
int32x8_t * acc = (int32x8_t *) D.x;
|
||||||
|
|
||||||
|
acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
|
||||||
|
true,
|
||||||
|
a_vec[0],
|
||||||
|
true,
|
||||||
|
b_vec[0],
|
||||||
|
acc[0],
|
||||||
|
false
|
||||||
|
);
|
||||||
|
#else
|
||||||
|
GGML_UNUSED(D);
|
||||||
|
GGML_UNUSED(A);
|
||||||
|
GGML_UNUSED(B);
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -306,5 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
if (amd_wmma_available(cc)) {
|
||||||
|
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue