WIP
commit 66f6d16265 (parent 215ebf6526)
@@ -892,27 +892,43 @@ __device__ __forceinline__ void ldmatrix_b(
static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");

uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
// uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
// const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
// unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
// uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
// constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
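// The two XOR folds above match the swizzle used by tileMemcpySwizzleStore further
// down (masks 0b10000 >> 4 and 0b1100 >> 2 on float4 indices); here the offset is
// in half elements (8 halves per float4), so the same masks appear shifted left by
// 3 as 0b10000000 and 0b1100000. Each lane computes the swizzled location of the
// row it owns before converting it to a shared-memory address for ldmatrix.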

// asm volatile (
// "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
// "{%0, %1, %2, %3}, [%4];"
// : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
// : "r"(src_addr)
// );

// 0
asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
: "r"(src_addr)
);
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
: "r"(src_addr)
);

asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
: "r"(src_addr ^ 0b1000000)
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);

src_addr += 8 * smem_stride_;
src_addr ^= 0b10000;
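// advance src_addr to the next 8-row k-slice; the XOR updates the swizzle
// incrementally (bits 4 and 5 of the byte address are exactly the bits the swizzle
// above folds into), so the pointer stays consistent with the swizzled layout
// without recomputing the offset from scratch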

asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -925,10 +941,12 @@ __device__ __forceinline__ void ldmatrix_b(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
: "r"(src_addr ^ 0b1000000)
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);

src_addr += 8 * smem_stride_;
// src_addr += 8 * smem_stride_;
src_addr ^= 0b110000;

asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -941,10 +959,11 @@ __device__ __forceinline__ void ldmatrix_b(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
: "r"(src_addr ^ 0b1000000)
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);

src_addr += 8 * smem_stride_;
src_addr ^= 0b10000;

asm volatile (
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
@@ -957,7 +976,8 @@ __device__ __forceinline__ void ldmatrix_b(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
"{%0, %1, %2, %3}, [%4];"
: "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
: "r"(src_addr ^ 0b1000000)
// : "r"(src_addr ^ 0b1000000)
: "r"(src_addr + 32 * smem_stride_)
);

}
@@ -1038,7 +1058,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// prefetch the first block tile of A,B into shared memory
// half* A_block_gmem = input + (block_m * BM * A_stride);
half* A_block_gmem = input;
half* B_block_gmem = weight + (block_n * weightKOffset);
half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);

@@ -1053,16 +1073,17 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,

if (block_k != num_block_tiles_k)
{
half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN);
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
tileMemcpyLoad<BK, BN, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
// half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
half* A_block_gmem = input;
half* B_block_gmem = kernel + (block_n * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
}
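// the next block tile is prefetched into registers here while the MMA loop below
// consumes the tile currently in shared memory; the register cache is flushed to
// shared memory by the swizzled stores after the compute loop (see the
// tileMemcpySwizzleStore calls further down), so global-memory latency overlaps
// with the tensor-core work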
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN);
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);

ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN>(B_warp_tile, B_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);

// outer product between mma tiles
#pragma unroll
@@ -1097,47 +1118,47 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
offset_direction = -1 * offset_direction;
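// ping-pong double buffering: offset_direction flips sign every iteration, so the
// block tile pointers alternate between the two shared-memory buffers spaced
// BUFFER_SIZE apart while one buffer is read and the other is refilled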

tileMemcpySwizzleStoreA<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStoreB<BN, NUM_THREADS> (B_gmem_cache_reg, B_block_smem);
tileMemcpySwizzleStore<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStore<BN, NUM_THREADS, 4>(B_gmem_cache_reg, B_block_smem);
}
}

//////////////
// epilogue //
//////////////
half alpha_ = (half)alpha;
half beta_ = (half)beta;
half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];
// half alpha_ = (half)alpha;
// half beta_ = (half)beta;
// half C_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4];

// calculate pointers for this warp's C and D tiles
half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
// // calculate pointers for this warp's C and D tiles
// half* C_block_gmem = C + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* C_warp_gmem = C_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);
// half* D_block_gmem = D + (block_m * BM_dim * CD_stride) + (block_n * BN_dim);
// half* D_warp_gmem = D_block_gmem + (warp_m * WM_dim * CD_stride) + (warp_n * WN_dim);

for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));
// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* C_mma_tile = C_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// ldmatrix_m16n8_gmem(C_mma_tile, C_register[mma_m][mma_n], N * sizeof(half));

// scale C by beta
acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
}
}
// // scale C by beta
// acc_register_[mma_m][mma_n][0] = acc_register_[mma_m][mma_n][0] * alpha_ + C_register[mma_m][mma_n][0] * beta_;
// acc_register_[mma_m][mma_n][1] = acc_register_[mma_m][mma_n][1] * alpha_ + C_register[mma_m][mma_n][1] * beta_;
// acc_register_[mma_m][mma_n][2] = acc_register_[mma_m][mma_n][2] * alpha_ + C_register[mma_m][mma_n][2] * beta_;
// acc_register_[mma_m][mma_n][3] = acc_register_[mma_m][mma_n][3] * alpha_ + C_register[mma_m][mma_n][3] * beta_;
// }
// }

for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
{
half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
}
}
// for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
// {
// for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++)
// {
// half* D_mma_tile = D_warp_gmem + (mma_m * MMA_M_dim * CD_stride) + (mma_n * MMA_N_dim);
// stmatrix_m16n8(D_mma_tile, acc_register_[mma_m][mma_n], N * sizeof(half));
// }
// }

}

@@ -190,15 +190,176 @@ template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoad(
__device__ __forceinline__ void tileMemcpyLoadA(
half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int src_stride
// const unsigned int src_stride,
const unsigned int block_k,
const unsigned int inChannelOffset,
param_t param
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
const unsigned int src_stride_vectorized = src_stride / 8;
// const unsigned int src_stride_vectorized = src_stride / 8;

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten out 2d grid of threads in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
// dst_reg[i] = src_float4[src_index];
// thread_row += ROW_STEP;
unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
unsigned int inOffset = n * param.c * param.h * param.w;
// TODO: next block_k loop
const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // filter row (r)
const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter col (s)
const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // input channel (c)
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.R && curS < param.S && curC < param.c){
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
} else {
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
}
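// The gather above is the implicit im2col step: the GEMM row gemm_i decodes to
// (n, oh, ow) via the OHOW/OW fast divisions, and the GEMM column
// block_k + thread_col*8 decodes to (r, s, c) via the SC/C fast divisions; the
// element fetched is input[n][oh*u - p + r*d_h][ow*v - q + s*d_w][c], assuming an
// NHWC layout with inChannelOffset == W*C, and positions that fall into the
// padding or past R/S/C are zero-filled.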

template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadB(
half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int src_stride,
param_t param
)
{
// reinterpret input/output as float4
float4* src_float4 = reinterpret_cast<float4*>(src);
// const unsigned int src_stride_vectorized = src_stride / 8;

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten out 2d grid of threads in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

const uint curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // filter row (r)
const uint curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter col (s)
const uint curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // input channel (c)

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
// dst_reg[i] = src_float4[src_index];
// thread_row += ROW_STEP;
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
} else { // out of range: zero-fill the 8 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
}
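// tileMemcpyLoadB reads the weight tile directly from global memory: the row is the
// output channel (bounded by param.k) and the column is block_k + thread_col*8 into
// the flattened (r, s, c) axis, with src_stride presumably equal to weightKOffset,
// i.e. R*S*C halves per filter; rows or columns past the valid range are zero-filled
// so the MMA fragments stay well defined.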

// template<unsigned int TILE_ROWS,
// unsigned int TILE_COLS,
// unsigned int NUM_THREADS,
// unsigned int SWIZZLE_BITS,
// unsigned int ELEMENTS_PER_THREAD>
// __device__ __forceinline__ void tileMemcpySwizzleStoreB(
// float4 src_reg[ELEMENTS_PER_THREAD],
// half* dst
// )
// {
// constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;

// // reinterpret input/output as float4
// float4* dst_float4 = reinterpret_cast<float4*>(dst);

// // # of threads is multiple of # of columns in the tile
// constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
// static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// // flatten out 2d grid of threads in order of increasing threadIdx.x
// const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// // assign each thread a row/column in the tile, calculate how many iterations we need
// // to cover the whole tile
// constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
// constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
// const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// // compile time check that we provided the right amount of registers for storage
// static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

// #pragma unroll
// for (unsigned int i = 0; i < NUM_ITERS; i++)
// {
// // apply swizzle to the dst index
// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
// dst_float4[dst_index] = src_reg[i];
// thread_row += ROW_STEP;
// }
// }

// same as above but without the swizzle
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyStore(
float4 src_reg[ELEMENTS_PER_THREAD],
half* dst,
unsigned int dst_stride_float4
)
{
// reinterpret input/output as float4
float4* dst_float4 = reinterpret_cast<float4*>(dst);

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
@@ -220,11 +381,72 @@ __device__ __forceinline__ void tileMemcpyLoad(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
dst_reg[i] = src_float4[src_index];
// compute the dst index (no swizzle here)
unsigned int dst_index = thread_row * dst_stride_float4 + thread_col;
dst_float4[dst_index] = src_reg[i];
thread_row += ROW_STEP;
}
}

// this is a special case of the above for when TILE_COLS == 32
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpySwizzleStore(
const float4 (&src_reg)[ELEMENTS_PER_THREAD],
half* dst
)
{
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;

// reinterpret input/output as float4
float4* dst_float4 = reinterpret_cast<float4*>(dst);

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten out 2d grid of threads in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
dst_float4[dst_index] = src_reg[i];
thread_row += ROW_STEP;
}
}
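// Worked example of the two folds, with TILE_COLS_VECTORIZED = 4: the col-0 float4
// of rows 0..7 is stored at columns 0, 1, 2, 3, 1, 0, 3, 2 of the swizzled tile
// (e.g. row 6, col 0: dst_index 24 -> 25 after the first XOR -> 27 after the second).
// The same masks, expressed on half-element offsets (0b10000000 and 0b1100000),
// show up on the load side in ldmatrix_b, so the swizzled stores and loads agree
// on the shared-memory layout.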

__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
uint32_t address;
asm("{\n\t"
" .reg .u64 u64addr;\n\t"
" cvta.to.shared.u64 u64addr, %1;\n\t"
" cvt.u32.u64 %0, u64addr;\n\t"
"}"
: "=r"(address)
: "l"(pointer));
return address;
}
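// cvta_to_shared_u32 converts a generic pointer into the 32-bit shared-state-space
// address that ldmatrix/stmatrix take as their [%N] operand (comparable to the
// __cvta_generic_to_shared intrinsic); typical use, as in ldmatrix_b above:
//   uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);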

#endif

#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256