disable m16n8k16 mma for ampere for now
parent 0cb1ff419a
commit b4530b4f8b
@@ -7,6 +7,9 @@
 typedef unsigned int uint;
+#define GGML_CUDA_CC_RUBIN 10000
 constexpr uint WARPSIZE = 32;
 #define CUDA_NCHW_2_NHWC_TILE_DIM 32
 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8
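
The new GGML_CUDA_CC_RUBIN constant sits above any compute capability a real GPU reports, so every guard switched from GGML_CUDA_CC_AMPERE to GGML_CUDA_CC_RUBIN in the hunks below compiles the m16n8k16 branch out and keeps the smaller-K fallback. A minimal sketch of the effect, not taken from the file (the macro name is suffixed to avoid clashing with the real define; 800/900 are the standard __CUDA_ARCH__ values for Ampere/Hopper):

// Sketch: a guard of >= 10000 can never be satisfied by a shipping __CUDA_ARCH__
// (Ampere reports 800, Hopper 900, current Blackwell parts 1000/1200), so the
// first branch is dead code and every supported GPU builds the fallback.
#define GGML_CUDA_CC_RUBIN_SKETCH 10000

__device__ __forceinline__ int mma_k_width_sketch() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN_SKETCH
    return 16;   // m16n8k16 path, disabled for now
#else
    return 8;    // m16n8k8 fallback, taken on all current hardware
#endif
}
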
@@ -343,7 +346,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_a(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
 #else
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
@@ -351,13 +354,13 @@ __device__ __forceinline__ void ldmatrix_a(
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
     static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
 #else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
 #endif
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
 #else
     uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
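
For context, the reinterpret_cast above only re-views the same fragment bytes: the ldmatrix and mma operands are 32-bit registers, each packing two halfs. A standalone sketch of that aliasing, with shapes matching the k8 fallback rather than copied from the file:

#include <cuda_fp16.h>

// Sketch: 4 halfs per A fragment alias to 2 packed 32-bit registers.
__device__ void fragment_alias_sketch() {
    half reg[8][4][4];                                   // [m tiles][k tiles][halfs]
    uint32_t (&reg_)[8][4][2] =
        reinterpret_cast<uint32_t(&)[8][4][2]>(reg);     // same bytes, 2 halfs per u32
    reg_[0][0][0] = 0u;                                  // clears reg[0][0][0..1]
}
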
@@ -403,7 +406,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b10000;
 
     // 1
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -471,7 +474,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b110000;
 
     // 2
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -537,7 +540,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b10000;
 
     // 3
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -610,7 +613,7 @@ __device__ __forceinline__ void ldmatrix_a(
 template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_b(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
 #else
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
@@ -618,14 +621,14 @@ __device__ __forceinline__ void ldmatrix_b(
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
 #else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
 #endif
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t(&)[2][8][2]>(reg);
 #else
     uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
@@ -637,7 +640,7 @@ __device__ __forceinline__ void ldmatrix_b(
     constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
 
     // 0
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -670,7 +673,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b10000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -702,7 +705,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b110000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -734,7 +737,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b10000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -790,7 +793,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // const unsigned int NKPQ = param.n * KPQ;
 
     // loop bounds, constexpr where possible allows for loop unrolling
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     constexpr unsigned int mma_tiles_per_warp_k = 2;
 #else
     constexpr unsigned int mma_tiles_per_warp_k = 4;
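
The halved tile count follows directly from the MMA shapes: both settings cover the same K slice per warp, as the compile-time check below illustrates (the 32-element slice width is inferred from the 4-versus-2 asserts, it is not stated in the diff):

// Sketch: 4 m16n8k8 tiles and 2 m16n8k16 tiles span the same 32 elements of K,
// which is why mma_tiles_per_warp_k drops from 4 to 2 on the k16 path.
constexpr unsigned int k_elems_per_warp_k8  = 4 * 8;
constexpr unsigned int k_elems_per_warp_k16 = 2 * 16;
static_assert(k_elems_per_warp_k8 == k_elems_per_warp_k16, "same K coverage per warp");
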
@@ -829,7 +832,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // declare register storage
     // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
     uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
     uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
 #else
@@ -839,7 +842,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 
     // convenience cast to half for register storage
     half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
     half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
 #else
@@ -962,7 +965,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
         #pragma unroll
         for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
             asm volatile (
                 "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                 "{%0, %1}, "