WIP
This commit is contained in:
parent 1b69ed44c6
commit 215ebf6526
@@ -732,6 +732,236 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
    const half* src,
    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
)
{
    static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");

    uint32_t (&reg_)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
    swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
    uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes

    // k fragment 0, m tiles 0-1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][0][0]), "=r"(reg_[0][0][1]), "=r"(reg_[1][0][0]), "=r"(reg_[1][0][1])
        : "r"(src_addr)
    );

    // k fragment 0, m tiles 2-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][0][0]), "=r"(reg_[2][0][1]), "=r"(reg_[3][0][0]), "=r"(reg_[3][0][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // k fragment 0, m tiles 4-5
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][0][0]), "=r"(reg_[4][0][1]), "=r"(reg_[5][0][0]), "=r"(reg_[5][0][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // k fragment 0, m tiles 6-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][0][0]), "=r"(reg_[6][0][1]), "=r"(reg_[7][0][0]), "=r"(reg_[7][0][1])
        : "r"(src_addr + 96 * smem_stride_)
    );

    src_addr ^= 0b10000;

    // k fragment 1, m tiles 0-1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
        : "r"(src_addr)
    );

    // k fragment 1, m tiles 2-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // k fragment 1, m tiles 4-5
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // k fragment 1, m tiles 6-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
        : "r"(src_addr + 96 * smem_stride_)
    );

    src_addr ^= 0b110000;

    // k fragment 2, m tiles 0-1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
        : "r"(src_addr)
    );

    // k fragment 2, m tiles 2-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // k fragment 2, m tiles 4-5
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // k fragment 2, m tiles 6-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
        : "r"(src_addr + 96 * smem_stride_)
    );

    src_addr ^= 0b10000;

    // k fragment 3, m tiles 0-1
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
        : "r"(src_addr)
    );

    // k fragment 3, m tiles 2-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
        : "r"(src_addr + 32 * smem_stride_)
    );

    // k fragment 3, m tiles 4-5
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
        : "r"(src_addr + 64 * smem_stride_)
    );

    // k fragment 3, m tiles 6-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
        : "r"(src_addr + 96 * smem_stride_)
    );
}
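
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of this commit: the XOR swizzle at the top of
// ldmatrix_a() can be reproduced on the host to inspect the lane -> offset
// mapping (smem_stride is assumed here to be the 32-half tile width used by
// the copy helpers later in this patch). Printing swizzle_a_offset(lane) for
// lane = 0..31 shows which shared-memory line each lane ends up reading.
// ---------------------------------------------------------------------------
static unsigned int swizzle_a_offset(unsigned int lane, unsigned int smem_stride = 32) {
    unsigned int logical  = (lane % 32) * smem_stride;               // same expression as in ldmatrix_a
    unsigned int swizzled = logical ^ ((logical & 0b10000000) >> 4);
    swizzled              = swizzled ^ ((swizzled & 0b1100000) >> 2);
    return swizzled;                                                 // offset in half elements
}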

template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_b(
    const half* src,
    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
)
{
    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
    static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");

    uint32_t (&reg_)[4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
    const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
    uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes

    // k tile 0, n tiles 0-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
        : "r"(src_addr)
    );

    // k tile 0, n tiles 4-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
        : "r"(src_addr ^ 0b1000000)
    );

    src_addr += 8 * smem_stride_;

    // k tile 1, n tiles 0-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
        : "r"(src_addr)
    );

    // k tile 1, n tiles 4-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
        : "r"(src_addr ^ 0b1000000)
    );

    src_addr += 8 * smem_stride_;

    // k tile 2, n tiles 0-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
        : "r"(src_addr)
    );

    // k tile 2, n tiles 4-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
        : "r"(src_addr ^ 0b1000000)
    );

    src_addr += 8 * smem_stride_;

    // k tile 3, n tiles 0-3
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
        : "r"(src_addr)
    );

    // k tile 3, n tiles 4-7
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
        : "r"(src_addr ^ 0b1000000)
    );
}
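
// Illustrative sketch, not part of this commit: both loaders above feed ldmatrix
// a 32-bit shared-memory address obtained from cvta_to_shared_u32(). That helper
// is not shown in this hunk; one minimal way to write such a conversion (the
// name below is hypothetical) is via the __cvta_generic_to_shared intrinsic:
__device__ __forceinline__ uint32_t cvta_to_shared_u32_sketch(const void * ptr) {
    // convert a generic pointer into the shared-memory window and truncate to the
    // 32-bit operand form that ldmatrix expects
    return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}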

template<const int BM, const int BN, const int BK, const int WM, const int WN,
         const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,

@@ -742,6 +972,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
    constexpr unsigned int MMA_N = 8;

    const unsigned int K = param.c * param.r * param.s;
    const uint PQ = param.Oh * param.Ow;
    const uint inChannelOffset = param.c * param.w;
    const uint weightKOffset = param.c * param.r * param.s;

    // for convenience/readability in index calculations
    const unsigned int A_stride = K;

@@ -801,12 +1034,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
    static_assert(NUM_THREADS == 256);
    float4 A_gmem_cache_reg[4];
    float4 B_gmem_cache_reg[4];

    // prefetch the first block tile of A,B into shared memory
    half* A_block_gmem = input + (block_m * BM * A_stride);
    half* B_block_gmem = weight + (block_n * BN);
    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, K);
    tileMemcpySwizzle<BK, BN, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
    // half* A_block_gmem = input + (block_m * BM * A_stride);
    half* A_block_gmem = input;
    half* B_block_gmem = weight + (block_n * weightKOffset);
    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);

    // construct const pointers to warp tiles for use inside the inner loop

@@ -864,7 +1098,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
        offset_direction = -1 * offset_direction;

        tileMemcpySwizzleStoreA<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
        tileMemcpySwizzleStore<BK, BN, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
        tileMemcpySwizzleStoreB<BN, NUM_THREADS> (B_gmem_cache_reg, B_block_smem);
        }
    }

@@ -1026,6 +1260,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
    param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW };
    params.SC_fastdiv = init_fastdiv_values(KW*IC);
    params.OW_fastdiv = init_fastdiv_values(OW);
    params.OHOW_fastdiv = init_fastdiv_values(OW*OH);
    params.C_fastdiv = init_fastdiv_values(IC);
    params.RS_fastdiv = init_fastdiv_values(KW*KH);
    params.S_fastdiv = init_fastdiv_values(KW);
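
// Illustrative sketch, not part of this commit: the fastdiv fields above just
// precompute magic-number divisors so the device code can split flattened
// indices without hardware division. The plain-arithmetic equivalent of the
// column decomposition used by the tile-copy helpers is:
static inline void decode_crs_index(unsigned int k, unsigned int KW, unsigned int IC,
                                    unsigned int & r, unsigned int & s, unsigned int & c) {
    r = k / (KW * IC);         // fastdiv(k, SC_fastdiv)                            -> kernel row
    s = (k % (KW * IC)) / IC;  // fastdiv(fastmodulo(k, SC_fastdiv), C_fastdiv)     -> kernel column
    c = k % IC;                // fastmodulo(fastmodulo(k, SC_fastdiv), C_fastdiv)  -> input channel
}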

@@ -23,31 +23,70 @@ typedef struct{
    uint3 C_fastdiv;
    uint3 RS_fastdiv;
    uint3 S_fastdiv;
    uint3 OHOW_fastdiv;
} param_t;

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int SWIZZLE_BITS>
__device__ __forceinline__ void tileMemcpySwizzle(
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
    half* src,
    half* dst,
    const unsigned int src_stride
)
{
    constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;
    // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS;

    // // reinterpret input/output as float4
    // float4* src_float4 = reinterpret_cast<float4*>(src);
    // float4* dst_float4 = reinterpret_cast<float4*>(dst);
    // const unsigned int src_stride_vectorized = src_stride / 8;

    // // # of threads is a multiple of # of columns in the tile
    // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
    // static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

    // // flatten the 2D grid of threads into 1D, in order of increasing threadIdx.x
    // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

    // // assign each thread a row/column in the tile, calculate how many iterations we need
    // // to cover the whole tile
    // constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
    // constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
    // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
    // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

    // #pragma unroll
    // for (unsigned int i = 0; i < NUM_ITERS; i++)
    // {
    //     // apply swizzle to the dst index
    //     const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
    //     unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
    //     dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
    //     if (thread_col * 8 < param.k && start_k + innerColA * 4 < end_k){
    //         float4 tmp = reinterpret_cast<const float4 *>(&src[thread_row * src_stride_vectorized + thread_col*8])[0];
    //         dst_float4[dst_index] = src_float4[src_index];
    //     }else{ // read 4 halves
    //         dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
    //     }
    //     thread_row += ROW_STEP;
    // }

    constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
    constexpr unsigned int SWIZZLE_BITS_1 = 4;
    constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
    constexpr unsigned int SWIZZLE_BITS_2 = 2;
    constexpr unsigned int TILE_COLS = 32;

    // reinterpret input/output as float4
    float4* src_float4 = reinterpret_cast<float4*>(src);
    // float4* src_float4 = reinterpret_cast<float4*>(src);
    float4* dst_float4 = reinterpret_cast<float4*>(dst);
    const unsigned int src_stride_vectorized = src_stride / 8;
    // const unsigned int src_stride_vectorized = src_stride / 8;

    // # of threads is a multiple of # of columns in the tile
    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

    // flatten the 2D grid of threads into 1D, in order of increasing threadIdx.x
    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

@@ -57,15 +96,24 @@ __device__ __forceinline__ void tileMemcpySwizzle(
    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

    // TODO: next block_k loop
    const uint curR = fastdiv(thread_col*8, param.SC_fastdiv);                                 // kernel row (r)
    const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv);    // kernel column (s)
    const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // input channel (c)

    #pragma unroll
    for (unsigned int i = 0; i < NUM_ITERS; i++)
    {
        // apply swizzle to the dst index
        const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
        const unsigned int src_index = thread_row * src_stride + thread_col * 8;
        unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS);
        dst_float4[dst_index] = src_float4[src_index];
        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
        if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){
            dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
        }else{ // read 4 halves
            dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
        }
        thread_row += ROW_STEP;
    }
}
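
// Illustrative sketch, not part of this commit: the two XOR steps above operate
// on the vectorized (float4) destination index, whose low two bits select one of
// the four 16-byte column groups in a 32-half row and whose upper bits are the
// row. A standalone copy of the mapping (constants taken from the function
// above) makes it easy to print and check for bank conflicts on the host:
static unsigned int swizzled_b_dst_index(unsigned int thread_row, unsigned int thread_col) {
    const unsigned int TILE_COLS_VECTORIZED = 32 / 8;  // 4 float4 columns per tile row
    unsigned int dst = thread_row * TILE_COLS_VECTORIZED + thread_col;
    dst = dst ^ ((dst & 0b10000) >> 4);                // SWIZZLE_MASK_1 / SWIZZLE_BITS_1
    dst = dst ^ ((dst & 0b1100) >> 2);                 // SWIZZLE_MASK_2 / SWIZZLE_BITS_2
    return dst;
}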

@@ -77,7 +125,9 @@ unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
    half* src,
    half* dst,
    const unsigned int src_stride
    // const unsigned int src_stride,
    const unsigned int inChannelOffset,
    param_t param
)
{
    constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;

@@ -87,14 +137,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
    constexpr unsigned int TILE_COLS = 32;

    // reinterpret input/output as float4
    float4* src_float4 = reinterpret_cast<float4*>(src);
    // float4* src_float4 = reinterpret_cast<float4*>(src);
    float4* dst_float4 = reinterpret_cast<float4*>(dst);
    const unsigned int src_stride_vectorized = src_stride / 8;
    // const unsigned int src_stride_vectorized = src_stride / 8;

    // # of threads is a multiple of # of columns in the tile
    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
    static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

    // flatten the 2D grid of threads into 1D, in order of increasing threadIdx.x
    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;

@@ -104,16 +153,35 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
    constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
    const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

    #pragma unroll
    for (unsigned int i = 0; i < NUM_ITERS; i++)
    {
        unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
        unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
        unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
        int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
        int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
        unsigned int inOffset = n * param.c * param.h * param.w;
        // TODO: next block_k loop
        const uint curR = fastdiv(thread_col*8, param.SC_fastdiv);                                 // kernel row (r)
        const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv);    // kernel column (s)
        const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // input channel (c)
        int curH = posh_ori + curR * param.d_h; // input h
        int curW = posw_ori + curS * param.d_w; // input w
        // apply swizzle to the dst index
        const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
        unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
        dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
        dst_float4[dst_index] = src_float4[src_index];
        if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
            curR < param.R && curS < param.S && curC < param.c){
            // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
            const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
            dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
        } else{
            dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
        }
        thread_row += ROW_STEP;
    }
}
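
// Illustrative sketch, not part of this commit: the index arithmetic in the loop
// above is the standard implicit-GEMM coordinate mapping. For a GEMM row i
// (over N*OH*OW) and a GEMM column k (over R*S*C) it reduces to the following
// plain-arithmetic reference, using the same param_t field meanings (u,v =
// strides, p,q = paddings, d_h,d_w = dilations):
static inline bool implicit_gemm_src_coord(unsigned int i, unsigned int k,
                                           unsigned int OH, unsigned int OW,
                                           unsigned int C, unsigned int S,
                                           int u, int v, int p, int q, int d_h, int d_w,
                                           unsigned int H, unsigned int W,
                                           unsigned int & n, int & h, int & w, unsigned int & c) {
    n = i / (OH * OW);                              // batch index
    const unsigned int oh = (i % (OH * OW)) / OW;   // output row
    const unsigned int ow = i % OW;                 // output column
    const unsigned int r = k / (S * C);             // kernel row
    const unsigned int s = (k % (S * C)) / C;       // kernel column
    c = k % C;                                      // input channel
    h = (int) (oh * u) - p + (int) (r * d_h);       // input row, may fall outside the image
    w = (int) (ow * v) - q + (int) (s * d_w);       // input column, may fall outside the image
    return h >= 0 && w >= 0 && h < (int) H && w < (int) W;  // false -> the tile copy writes zeros
}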