Only enable m16n8k16 MMA on Ampere or above
parent ea438d8b0e · commit 9f498d29f1
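Context for the change (not part of the diff): the f16 tensor-core instruction mma.sync.aligned.m16n8k16 is only available from sm_80 (Ampere) onward, while Turing (sm_75) tops out at m16n8k8 for f16, so every m16n8k16 path below is wrapped in a __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE guard with an m16n8k8 fallback. A minimal sketch of the dispatch idea, reusing the ggml arch constants already used in the file (the helper name is hypothetical):

    // Illustration only: how the guard picks the MMA K extent per instruction.
    __device__ constexpr unsigned int mma_k_extent() {
    #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
        return 16;   // mma.sync.aligned.m16n8k16 (sm_80 and newer)
    #else
        return 8;    // mma.sync.aligned.m16n8k8  (sm_75)
    #endif
    }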
@ -343,15 +343,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
    const half* src,
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
#else
    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
#endif
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
    static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
    // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
#else
    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#endif

    // uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
#else
    uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
#endif
    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
    swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
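The three offset lines above are an XOR swizzle on the shared-memory offset, presumably to spread the 32 ldmatrix lanes across banks. A host-side sketch (smem_stride = 32 halves is an assumption, purely to print the mapping) makes the lane-to-offset permutation visible:

    // Standalone illustration of the swizzle above; not part of the kernel.
    #include <cstdio>
    int main() {
        const unsigned int smem_stride = 32;            // halves per shared-memory row (assumed)
        for (unsigned int lane = 0; lane < 32; ++lane) {
            unsigned int logical  = lane * smem_stride;
            unsigned int swizzled = logical ^ ((logical & 0b10000000) >> 4);
            swizzled              = swizzled ^ ((swizzled & 0b1100000) >> 2);
            printf("lane %2u: logical %4u -> swizzled %4u\n", lane, logical, swizzled);
        }
        return 0;
    }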
@ -392,38 +402,8 @@ __device__ __forceinline__ void ldmatrix_a(
    src_addr ^= 0b10000;

    // // 1
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
    //     : "r"(src_addr)
    // );

    // // 1
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
    //     : "r"(src_addr + 32 * smem_stride_)
    // );

    // // 1
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
    //     : "r"(src_addr + 64 * smem_stride_)
    // );

    // // 1
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
    //     : "r"(src_addr + 96 * smem_stride_)
    // );

    // 1
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@ -455,41 +435,43 @@ __device__ __forceinline__ void ldmatrix_a(
        : "r"(src_addr + 96 * smem_stride_)
    );

#else
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
        : "r"(src_addr)
    );

    // 1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // 1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // 1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
        : "r"(src_addr + 96 * smem_stride_)
    );
#endif

    src_addr ^= 0b110000;

    // // 2
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
    //     : "r"(src_addr)
    // );

    // // 2
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
    //     : "r"(src_addr + 32 * smem_stride_)
    // );

    // // 2
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
    //     : "r"(src_addr + 64 * smem_stride_)
    // );

    // // 2
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
    //     : "r"(src_addr + 96 * smem_stride_)
    // );

    // 2
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@ -520,42 +502,42 @@ __device__ __forceinline__ void ldmatrix_a(
        : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
        : "r"(src_addr + 96 * smem_stride_)
    );
#else
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
        : "r"(src_addr)
    );

    // 2
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // 2
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // 2
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
        : "r"(src_addr + 96 * smem_stride_)
    );
#endif
    src_addr ^= 0b10000;

    // // 3
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
    //     : "r"(src_addr)
    // );

    // // 3
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
    //     : "r"(src_addr + 32 * smem_stride_)
    // );

    // // 3
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
    //     : "r"(src_addr + 64 * smem_stride_)
    // );

    // // 3
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
    //     : "r"(src_addr + 96 * smem_stride_)
    // );

    // 3
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@ -586,6 +568,38 @@ __device__ __forceinline__ void ldmatrix_a(
        : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3])
        : "r"(src_addr + 96 * smem_stride_)
    );
#else
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
        : "r"(src_addr)
    );

    // 3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // 3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // 3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
        : "r"(src_addr + 96 * smem_stride_)
    );
#endif
#else
    GGML_UNUSED(src);
    GGML_UNUSED(reg);
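For reference, each ldmatrix.sync.aligned.m8n8.x4 instruction above loads four 8x8 tiles of 16-bit elements from shared memory and leaves one 32-bit register per tile in every lane. A hypothetical standalone wrapper with the same shape (src_addr must already be a shared-state-space byte address, as in the kernel):

    // Sketch: one x4 ldmatrix = 4 tiles of 8x8 b16, one 32-bit destination register per tile per thread.
    __device__ __forceinline__ void ldmatrix_x4(uint32_t & r0, uint32_t & r1, uint32_t & r2, uint32_t & r3,
                                                uint32_t src_addr) {
        asm volatile (
            "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
            : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
            : "r"(src_addr)
        );
    }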
@ -596,17 +610,26 @@ __device__ __forceinline__ void ldmatrix_a(
template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_b(
    const half* src,
    // half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
#else
    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
#endif
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

    // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
#else
    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#endif
    static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");

    // uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t(&)[2][8][2]>(reg);
#else
    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
#endif
    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
    swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@ -614,21 +637,7 @@ __device__ __forceinline__ void ldmatrix_b(
    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes

    // 0
    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
    //     : "r"(src_addr)
    // );

    // asm volatile (
    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
    //     "{%0, %1, %2, %3}, [%4];"
    //     : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
    //     : "r"(src_addr + 32 * smem_stride_)
    // );

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
|
@ -636,30 +645,32 @@ __device__ __forceinline__ void ldmatrix_b(
|
|||
: "r"(src_addr)
|
||||
);
|
||||
|
||||
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#else
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
|
||||
: "r"(src_addr)
|
||||
);
|
||||
|
||||
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#endif
|
||||
|
||||
src_addr ^= 0b10000;
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
|
||||
// : "r"(src_addr)
|
||||
// );
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
|
||||
// : "r"(src_addr + 32 * smem_stride_)
|
||||
// );
|
||||
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
|
|
@ -673,23 +684,25 @@ __device__ __forceinline__ void ldmatrix_b(
|
|||
: "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#else
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
|
||||
: "r"(src_addr)
|
||||
);
|
||||
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#endif
|
||||
|
||||
src_addr ^= 0b110000;
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
|
||||
// : "r"(src_addr)
|
||||
// );
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
|
||||
// : "r"(src_addr + 32 * smem_stride_)
|
||||
// );
|
||||
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
|
|
@ -703,23 +716,25 @@ __device__ __forceinline__ void ldmatrix_b(
|
|||
: "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#else
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
|
||||
: "r"(src_addr)
|
||||
);
|
||||
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#endif
|
||||
|
||||
src_addr ^= 0b10000;
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
|
||||
// : "r"(src_addr)
|
||||
// );
|
||||
|
||||
// asm volatile (
|
||||
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
// "{%0, %1, %2, %3}, [%4];"
|
||||
// : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
|
||||
// : "r"(src_addr + 32 * smem_stride_)
|
||||
// );
|
||||
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
|
|
@ -733,6 +748,21 @@ __device__ __forceinline__ void ldmatrix_b(
|
|||
: "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#else
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
|
||||
: "r"(src_addr)
|
||||
);
|
||||
|
||||
asm volatile (
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
|
||||
: "r"(src_addr + 32 * smem_stride_)
|
||||
);
|
||||
#endif
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(reg);
|
||||
|
|
@ -760,7 +790,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
    const unsigned int NKPQ = param.n * KPQ;

    // loop bounds, constexpr where possible allows for loop unrolling
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    constexpr unsigned int mma_tiles_per_warp_k = 2;
#else
    constexpr unsigned int mma_tiles_per_warp_k = 4;
#endif
    constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
    constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
    const unsigned int z = blockIdx.z;
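Both branches above cover the same K extent per main-loop step, just with different MMA shapes; a sketch of the arithmetic (the names are hypothetical, and WARP_K = 32 is inferred from the two constants rather than read from the file):

    // 2 tiles * 16 halves (m16n8k16) == 4 tiles * 8 halves (m16n8k8) == 32 halves of K per warp per step.
    constexpr unsigned int WARP_K_ASSUMED = 32;
    #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    constexpr unsigned int MMA_K_USED = 16;   // mma.m16n8k16
    #else
    constexpr unsigned int MMA_K_USED = 8;    // mma.m16n8k8
    #endif
    static_assert(WARP_K_ASSUMED % MMA_K_USED == 0, "K tile count = WARP_K / MMA_K: 2 on Ampere, 4 on Turing");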
@ -787,14 +821,23 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
    // declare register storage
    // ptx instructions expect uint32_t registers, where each uint32_t is 2 halves packed together
    uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
#else
    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
#endif

    // convenience cast to half for register storage
    half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
#else
    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
#endif
    // accumulators start at 0
    for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
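The comment above about packed registers can be illustrated with a small self-contained helper (hypothetical, not part of the kernel): two __half values share one 32-bit register, which is exactly what the "r" constraints of the ldmatrix/mma asm hand to PTX.

    #include <cuda_fp16.h>
    #include <cstdint>
    #include <cstring>

    __device__ uint32_t pack_two_halves(__half lo, __half hi) {
        __half2 pair = __halves2half2(lo, hi);   // lo -> bits 15:0, hi -> bits 31:16
        uint32_t packed;
        memcpy(&packed, &pair, sizeof(packed));  // same 32 bits the asm sees as one "r" operand
        return packed;
    }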
@ -844,17 +887,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
        for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
            #pragma unroll
            for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
                // asm volatile (
                //     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                //     "{%0, %1}, "
                //     "{%2, %3}, "
                //     "{%4}, "
                //     "{%5, %6};"
                //     : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
                //     : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
                //       "r"(B_register[mma_k][mma_n])
                //       "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                // );
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
                asm volatile (
                    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                    "{%0, %1}, "
@ -866,6 +899,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                    "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]),
                    "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                );
#else
                asm volatile (
                    "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                    "{%0, %1}, "
                    "{%2, %3}, "
                    "{%4}, "
                    "{%5, %6};"
                    : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
                    : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
                      "r"(B_register[mma_k][mma_n]),
                      "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                );
#endif
            }
        }
    }
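As a self-contained reference for the register shapes used above (per the PTX ISA, an f16 m16n8k16 MMA takes four 32-bit A operands and two B operands per thread, versus two and one for m16n8k8), a minimal wrapper for the Ampere-path instruction; the helper name is hypothetical:

    // D = A * B + C in f16, one warp-wide tensor-core op (sm_80+).
    __device__ __forceinline__ void mma_m16n8k16_f16(uint32_t (&d)[2],
                                                     const uint32_t (&a)[4],
                                                     const uint32_t (&b)[2],
                                                     const uint32_t (&c)[2]) {
        asm volatile (
            "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
            "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};"
            : "=r"(d[0]), "=r"(d[1])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
              "r"(b[0]), "r"(b[1]),
              "r"(c[0]), "r"(c[1])
        );
    }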
@ -714,15 +714,15 @@ int main(void)
    // for(int i = 0; i < ggml_nelements(wino_res); i++) {
    // for(int i = 0; i < 26*38; i++) {
    for(int i = 0; i < conv2d_data.size(); i++) {
        float diff = fabs(im2col_data[i] - conv2d_data[i]);
        // if(diff > 0.5) {
            printf("(%7.3f, %7.3f, %.2f, %d) \n",
                   im2col_data[i], conv2d_data[i],
                   diff, i);
            // break;
        // }
    }
    // for(int i = 0; i < conv2d_data.size(); i++) {
    //     float diff = fabs(im2col_data[i] - conv2d_data[i]);
    //     // if(diff > 0.5) {
    //         printf("(%7.3f, %7.3f, %.2f, %d) \n",
    //                im2col_data[i], conv2d_data[i],
    //                diff, i);
    //         // break;
    //     // }
    // }

    ggml_free(model.ctx);
    ggml_backend_buffer_free(model.buffer);