bssrdf 2025-10-21 15:43:35 -04:00
parent 15484c9bd6
commit f0a480cc22
1 changed file with 176 additions and 0 deletions


@@ -730,6 +730,182 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
}
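// half-precision implicit-GEMM conv2d: accumulates with m16n8k8 tensor core
// mma instructions over swizzled, double-buffered shared memory tiles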
template<const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const half * __restrict__ kernel,
float * __restrict__ output,
const param_t param) {
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
const unsigned int K = param.c * param.r * param.s;
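// implicit GEMM view: the input patches form a (batch * out_h * out_w) x K
// matrix A and the filter a K x N matrix B, where K = c*r*s is the patch size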
const unsigned int N = param.k; // output channels (assumes param_t names this field `k`)
// for convenience/readability in index calculations
const unsigned int A_stride = K;
const unsigned int B_stride = N;
const unsigned int CD_stride = N;
// calculate how many bits of the shared memory column index get swizzled (only the B tile needs this here)
constexpr unsigned int SWIZZLE_BITS_B = int_log2(BN / 8);
// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
const unsigned int num_block_tiles_k = K / BK;
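// note: the main loop below assumes K is a multiple of BK (no tail iteration)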
// calculate block/warp indices
const unsigned int block_m = blockIdx.y;
const unsigned int block_n = blockIdx.x;
const unsigned int warp_m = threadIdx.y;
const unsigned int warp_n = threadIdx.x / 32;
// double buffering
extern __shared__ half shmem[];
half* A_block_smem = shmem;
half* B_block_smem = &shmem[BM * BK];
constexpr int BUFFER_SIZE = BM * BK + BK * BN;
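// the launch is assumed to allocate 2 * BUFFER_SIZE halfs of dynamic shared
// memory so the two buffers can alternate (see the offset flip in the main loop)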
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
// convenience cast to half for register storage
half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
// accumulators start at 0
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
acc_register_[mma_m][mma_n][0] = 0;
acc_register_[mma_m][mma_n][1] = 0;
acc_register_[mma_m][mma_n][2] = 0;
acc_register_[mma_m][mma_n][3] = 0;
}
}
// these register arrays cache values pre-fetched from global memory during the inner loop of the kernel
// the code is nicer if we hard-code it for these tile dimensions and number of threads
// since we perform this copy with float4 pointers, for these tile dimensions it works out to 4 float4s per thread for each of A and B
static_assert(BM == 256);
static_assert(BN == 256);
static_assert(BK == 32);
static_assert(NUM_THREADS == 256);
float4 A_gmem_cache_reg[4];
float4 B_gmem_cache_reg[4];
// prefetch the first block tile of A (input patches) and B (filter) into shared memory
const half* A = input;  // operand names A/B kept from the GEMM kernel this is adapted from
const half* B = kernel;
const half* A_block_gmem = A + (block_m * BM * A_stride);
const half* B_block_gmem = B + (block_n * BN);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, K);
tileMemcpySwizzle<BK, BN, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
// offset_direction flips sign each iteration to alternate between the two smem buffers
int offset_direction = 1;
for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++)
{
__syncthreads();
if (block_k != num_block_tiles_k)
{
const half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
const half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN);
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
tileMemcpyLoad<BK, BN, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN);
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN>(B_warp_tile, B_register_);
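// ldmatrix_a/ldmatrix_b load this warp's A and B mma fragments from the
// swizzled smem tiles into the registers consumed by the mma.sync calls below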
// outer product between mma tiles
#pragma unroll
for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++)
{
#pragma unroll
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
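// one m16n8k8 mma: D(16x8) = A(16x8) * B(8x8) + C(16x8), all f16;
// each 32-bit register holds two packed halfs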
asm volatile (
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3}, "
"{%4}, "
"{%5, %6};"
: "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
: "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
"r"(B_register[mma_k][mma_n])
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
);
}
}
}
if (block_k != num_block_tiles_k)
{
// switch smem buffers each iteration
A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction;
B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
offset_direction = -1 * offset_direction;
tileMemcpySwizzleStoreA<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStore<BK, BN, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
}
}
//////////////
// epilogue //
//////////////
// NOTE: this epilogue still mirrors the HGEMM kernel it was adapted from; conv
// has no separate C input or alpha/beta scaling, so alpha = 1 and beta = 0 are
// assumed here, and `output` is treated as the C/D buffer (assumption: the
// caller provides half-precision storage despite the float* signature)
const half alpha_ = __float2half(1.0f);
const half beta_ = __float2half(0.0f);
half* C = reinterpret_cast<half*>(output);
half* D = reinterpret_cast<half*>(output);
half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
// calculate pointers for this warp's C and D tiles
half* C_block_gmem = C + (block_m * BM * CD_stride) + (block_n * BN);
half* C_warp_gmem = C_block_gmem + (warp_m * WM * CD_stride) + (warp_n * WN);
half* D_block_gmem = D + (block_m * BM * CD_stride) + (block_n * BN);
half* D_warp_gmem = D_block_gmem + (warp_m * WM * CD_stride) + (warp_n * WN);
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M * CD_stride) + (mma_n * MMA_N);
ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
// blend: acc = alpha * acc + beta * C
acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
}
}
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M * CD_stride) + (mma_n * MMA_N);
stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
}
}
}
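// `int_log2` (used for SWIZZLE_BITS_B above) is assumed to be a constexpr
// helper defined elsewhere in this file; a minimal sketch of the behavior the
// kernel relies on (floor(log2(x)) for a power-of-two tile width):
//
//   static constexpr unsigned int int_log2(unsigned int x) {
//       unsigned int result = 0;
//       while (x >>= 1) {
//           ++result;
//       }
//       return result;
//   }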
#define NUM_VARIANTS 6
/*