m16n8k16 mma works; to be cleaned up

bssrdf 2025-11-12 10:26:01 -05:00
parent fac6f0adc3
commit c33e4301dc
2 changed files with 250 additions and 81 deletions
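A quick orientation note (derived from the diff below, not text from the commit): switching the MMA shape from m16n8k8 to m16n8k16 doubles the per-tile operand registers while halving the k-tile count, so a warp still covers K = 32 halfs per inner iteration but issues half as many mma instructions.

// per-warp-tile register footprint with f16 operands (32 threads per warp)
//   before: mma.m16n8k8,  mma_tiles_per_warp_k = 4, A tile = 2 x .b32 (4 halfs), B tile = 1 x .b32 (2 halfs)
//   after:  mma.m16n8k16, mma_tiles_per_warp_k = 2, A tile = 4 x .b32 (8 halfs), B tile = 2 x .b32 (4 halfs)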


@@ -343,13 +343,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
const half* src,
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4");
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
// static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
// uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -390,7 +392,104 @@ __device__ __forceinline__ void ldmatrix_a(
src_addr ^= 0b10000;
// // 1
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
// : "r"(src_addr)
// );
// // 1
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
// : "r"(src_addr + 32 * smem_stride_)
// );
// // 1
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
// : "r"(src_addr + 64 * smem_stride_)
// );
// // 1
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
// : "r"(src_addr + 96 * smem_stride_)
// );
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][0][2]), "=r"(reg_[0][0][3]), "=r"(reg_[1][0][2]), "=r"(reg_[1][0][3])
: "r"(src_addr)
);
// 1
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][0][2]), "=r"(reg_[2][0][3]), "=r"(reg_[3][0][2]), "=r"(reg_[3][0][3])
: "r"(src_addr + 32 * smem_stride_)
);
// 1
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[4][0][2]), "=r"(reg_[4][0][3]), "=r"(reg_[5][0][2]), "=r"(reg_[5][0][3])
: "r"(src_addr + 64 * smem_stride_)
);
// 1
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[6][0][2]), "=r"(reg_[6][0][3]), "=r"(reg_[7][0][2]), "=r"(reg_[7][0][3])
: "r"(src_addr + 96 * smem_stride_)
);
src_addr ^= 0b110000;
// // 2
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
// : "r"(src_addr)
// );
// // 2
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
// : "r"(src_addr + 32 * smem_stride_)
// );
// // 2
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
// : "r"(src_addr + 64 * smem_stride_)
// );
// // 2
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
// : "r"(src_addr + 96 * smem_stride_)
// );
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@@ -398,7 +497,7 @@ __device__ __forceinline__ void ldmatrix_a(
: "r"(src_addr)
);
// 1
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@@ -406,7 +505,7 @@ __device__ __forceinline__ void ldmatrix_a(
: "r"(src_addr + 32 * smem_stride_)
);
// 1
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@@ -414,7 +513,7 @@ __device__ __forceinline__ void ldmatrix_a(
: "r"(src_addr + 64 * smem_stride_)
);
// 1
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
@@ -422,46 +521,45 @@ __device__ __forceinline__ void ldmatrix_a(
: "r"(src_addr + 96 * smem_stride_)
);
src_addr ^= 0b110000;
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
: "r"(src_addr)
);
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
: "r"(src_addr + 32 * smem_stride_)
);
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
: "r"(src_addr + 64 * smem_stride_)
);
// 2
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
: "r"(src_addr + 96 * smem_stride_)
);
src_addr ^= 0b10000;
// // 3
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
// : "r"(src_addr)
// );
// // 3
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
// : "r"(src_addr + 32 * smem_stride_)
// );
// // 3
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
// : "r"(src_addr + 64 * smem_stride_)
// );
// // 3
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
// : "r"(src_addr + 96 * smem_stride_)
// );
// 3
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
: "=r"(reg_[0][1][2]), "=r"(reg_[0][1][3]), "=r"(reg_[1][1][2]), "=r"(reg_[1][1][3])
: "r"(src_addr)
);
@@ -469,7 +567,7 @@ __device__ __forceinline__ void ldmatrix_a(
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
: "=r"(reg_[2][1][2]), "=r"(reg_[2][1][3]), "=r"(reg_[3][1][2]), "=r"(reg_[3][1][3])
: "r"(src_addr + 32 * smem_stride_)
);
@@ -477,7 +575,7 @@ __device__ __forceinline__ void ldmatrix_a(
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
: "=r"(reg_[4][1][2]), "=r"(reg_[4][1][3]), "=r"(reg_[5][1][2]), "=r"(reg_[5][1][3])
: "r"(src_addr + 64 * smem_stride_)
);
@@ -485,7 +583,7 @@ __device__ __forceinline__ void ldmatrix_a(
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
: "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3])
: "r"(src_addr + 96 * smem_stride_)
);
#else
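For reference, a minimal sketch of how one m16n8k16 A fragment can be filled with a single ldmatrix x4 (hypothetical helper, not part of this commit): lanes 0-15 supply row addresses for the k = 0..7 half of a 16x16 tile and lanes 16-31 for the k = 8..15 half, so destination registers 0-1 and 2-3 hold the same two k-halves that the passes above write into reg_[m][k][0..1] and reg_[m][k][2..3]. The names load_a_m16n8k16, tile and ldm are assumptions, the tile is taken to be unswizzled row-major shared memory, and unlike the kernel the sketch does not pair two m-tiles per x4 load.

// Sketch only: 'tile' is an unswizzled 16x16 row-major half tile in shared memory
// with row stride 'ldm' halfs; ldmatrix needs sm_75 or newer.
__device__ __forceinline__ void load_a_m16n8k16(uint32_t (&a)[4], const half * tile, unsigned ldm) {
    const unsigned lane = threadIdx.x % 32;
    // lanes 0-7/8-15 address rows 0-7/8-15 of the k = 0..7 half,
    // lanes 16-23/24-31 address the same rows of the k = 8..15 half
    const half * row_ptr = tile + (lane % 16) * ldm + (lane / 16) * 8;
    const uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(row_ptr));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
        : "r"(addr));
}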
@@ -498,14 +596,17 @@ __device__ __forceinline__ void ldmatrix_a(
template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_b(
const half* src,
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
// half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
// static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
// uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t(&)[2][8][2]>(reg);
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -513,10 +614,25 @@ __device__ __forceinline__ void ldmatrix_b(
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
// 0
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
// : "r"(src_addr)
// );
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
// : "r"(src_addr + 32 * smem_stride_)
// );
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
: "=r"(reg_[0][0][0]), "=r"(reg_[0][1][0]), "=r"(reg_[0][2][0]), "=r"(reg_[0][3][0])
: "r"(src_addr)
);
@@ -524,55 +640,97 @@ __device__ __forceinline__ void ldmatrix_b(
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
: "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0])
: "r"(src_addr + 32 * smem_stride_)
);
src_addr ^= 0b10000;
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
// : "r"(src_addr)
// );
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
// : "r"(src_addr + 32 * smem_stride_)
// );
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
: "r"(src_addr)
);
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][0][1]), "=r"(reg_[0][1][1]), "=r"(reg_[0][2][1]), "=r"(reg_[0][3][1])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
: "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1])
: "r"(src_addr + 32 * smem_stride_)
);
src_addr ^= 0b110000;
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
// : "r"(src_addr)
// );
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
// : "r"(src_addr + 32 * smem_stride_)
// );
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
: "r"(src_addr)
);
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][0][0]), "=r"(reg_[1][1][0]), "=r"(reg_[1][2][0]), "=r"(reg_[1][3][0])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
: "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0])
: "r"(src_addr + 32 * smem_stride_)
);
src_addr ^= 0b10000;
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
// : "r"(src_addr)
// );
// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
// : "r"(src_addr + 32 * smem_stride_)
// );
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
: "r"(src_addr)
);
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][0][1]), "=r"(reg_[1][1][1]), "=r"(reg_[1][2][1]), "=r"(reg_[1][3][1])
: "r"(src_addr)
);
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
: "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1])
: "r"(src_addr + 32 * smem_stride_)
);
#else
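Similarly, a minimal sketch of the per-thread m16n8k16 B fragment that the new reg_[k][n][2] indexing targets (hypothetical helper, not this kernel's ldmatrix path): per 16x8 (k x n) tile each thread owns two .b32 registers, the first covering k = 0..7 and the second k = 8..15, each packing B[k][n] and B[k+1][n] for k = (lane%4)*2 and n = lane/4. The names pack_b_m16n8k16, tile and ldn are assumptions, and the tile is taken to be unswizzled row-major.

// Sketch only: packs one m16n8k16 B fragment from an unswizzled 16x8 (k x n)
// row-major half tile with row stride 'ldn' halfs.
__device__ __forceinline__ void pack_b_m16n8k16(uint32_t (&b)[2], const half * tile, unsigned ldn) {
    const unsigned lane = threadIdx.x % 32;
    const unsigned n = lane / 4;        // column owned by this thread
    const unsigned k = (lane % 4) * 2;  // first of two consecutive k rows
    // the lower 16 bits hold the smaller k index, the upper 16 bits the next one
    const half2 lo = __halves2half2(tile[(k    ) * ldn + n], tile[(k + 1) * ldn + n]);
    const half2 hi = __halves2half2(tile[(k + 8) * ldn + n], tile[(k + 9) * ldn + n]);
    b[0] = *reinterpret_cast<const uint32_t *>(&lo);
    b[1] = *reinterpret_cast<const uint32_t *>(&hi);
}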
@@ -602,7 +760,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const unsigned int NKPQ = param.n * KPQ;
// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_k = 2;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
const unsigned int z = blockIdx.z;
@@ -629,13 +787,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
// convenience cast to half for register storage
half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
// accumulators start at 0
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
@@ -685,15 +843,26 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
// asm volatile (
// "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
// "{%0, %1}, "
// "{%2, %3}, "
// "{%4}, "
// "{%5, %6};"
// : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
// : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
// "r"(B_register[mma_k][mma_n])
// "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
// );
asm volatile (
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3}, "
"{%4}, "
"{%5, %6};"
"{%2, %3, %4, %5}, "
"{%6, %7}, "
"{%8, %9};"
: "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
: "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
"r"(B_register[mma_k][mma_n])
: "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]),
"r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1])
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
);
}
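For comparison, a self-contained sketch of the new instruction's operand shapes (assumed wrapper, not part of the commit): m16n8k16 with f16 accumulate takes A as four .b32 registers (8 halfs), B as two (4 halfs), and C/D as two (4 halfs); per the PTX ISA this shape needs sm_80 or newer. In the kernel the accumulator registers are passed as both the C inputs and the D outputs, so each MMA accumulates in place.

__device__ __forceinline__ void mma_m16n8k16_f16(uint32_t (&d)[2],
                                                 const uint32_t (&a)[4],
                                                 const uint32_t (&b)[2],
                                                 const uint32_t (&c)[2]) {
    // D = A * B + C, all operands f16
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};"
        : "=r"(d[0]), "=r"(d[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "r"(c[0]), "r"(c[1]));
}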


@@ -301,7 +301,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(1920,640,32,32,3,3)
// std::make_tuple(1280,1280,16,16,3,3),
std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(320,640,32,32,3,3),
// std::make_tuple(4,320,96,128,3,3),
// std::make_tuple(320,4,96,128,3,3),
@@ -317,7 +317,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(1920,1280,26,38,3,3),
// std::make_tuple(2560,1280,26,38,3,3),
// std::make_tuple(320,1280,26,38,3,3),
std::make_tuple(512,512,104,152,3,3),
// std::make_tuple(512,512,104,152,3,3),
// std::make_tuple(512,512,208,304,3,3),
// std::make_tuple(512,256,416,608,3,3),
// std::make_tuple(256,128,832,1216,3,3),
@@ -653,7 +653,7 @@ int main(void)
int k = 0;
// for (auto c : configs_sdxl_768){
// for (auto c : configs_sdxl_1024){
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),