disable m16n8k16 mma for ampere for now

This commit is contained in:
bssrdf 2025-11-14 12:11:52 -05:00
parent 0cb1ff419a
commit b4530b4f8b
1 changed files with 20 additions and 17 deletions

View File

@ -7,6 +7,9 @@
typedef unsigned int uint;
#define GGML_CUDA_CC_RUBIN 10000
constexpr uint WARPSIZE = 32;
#define CUDA_NCHW_2_NHWC_TILE_DIM 32
#define CUDA_NCHW_2_NHWC_BLOCK_NM 8
@ -343,7 +346,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
const half* src,
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
#else
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
@ -351,13 +354,13 @@ __device__ __forceinline__ void ldmatrix_a(
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
#else
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#endif
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
#else
uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
@ -403,7 +406,7 @@ __device__ __forceinline__ void ldmatrix_a(
src_addr ^= 0b10000;
// 1
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -471,7 +474,7 @@ __device__ __forceinline__ void ldmatrix_a(
src_addr ^= 0b110000;
// 2
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -537,7 +540,7 @@ __device__ __forceinline__ void ldmatrix_a(
src_addr ^= 0b10000;
// 3
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -610,7 +613,7 @@ __device__ __forceinline__ void ldmatrix_a(
template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_b(
const half* src,
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
#else
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
@ -618,14 +621,14 @@ __device__ __forceinline__ void ldmatrix_b(
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
#else
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
#endif
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t(&)[2][8][2]>(reg);
#else
uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
@ -637,7 +640,7 @@ __device__ __forceinline__ void ldmatrix_b(
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
// 0
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -670,7 +673,7 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b10000;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -702,7 +705,7 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b110000;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -734,7 +737,7 @@ __device__ __forceinline__ void ldmatrix_b(
src_addr ^= 0b10000;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@ -790,7 +793,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// const unsigned int NKPQ = param.n * KPQ;
// loop bounds, constexpr where possible allows for loop unrolling
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
constexpr unsigned int mma_tiles_per_warp_k = 2;
#else
constexpr unsigned int mma_tiles_per_warp_k = 4;
@ -829,7 +832,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
#else
@ -839,7 +842,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// convenience cast to half for register storage
half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
#else
@ -962,7 +965,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1}, "