conv3d WIP: turn on tensor cores; NCDHW2NDHWC to be worked out
This commit is contained in:
parent
a5b68bcea7
commit
3f5c5045da
|
|
@ -62,28 +62,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
|
|||
}
|
||||
}
|
||||
|
||||
template<const int layout>
|
||||
__device__ int4 inputIndices(const uint kidx, param_t param) {
|
||||
|
||||
const uint cur0 = fastdiv(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const uint cur0_res = fastmodulo(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const uint cur1_res = fastmodulo(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastdiv(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur3 = fastmodulo(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur3 : cur0;
|
||||
const uint curT = layout == 0 ? cur0 : cur1;
|
||||
const uint curR = layout == 0 ? cur1 : cur2;
|
||||
const uint curS = layout == 0 ? cur2 : cur3;
|
||||
return make_int4(curC, curT, curR, curS);
|
||||
|
||||
}
|
||||
|
||||
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
|
||||
|
|
@ -553,7 +531,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
|
||||
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
|
||||
__device__ __forceinline__ void ldmatrix_a(
|
||||
|
|
@ -805,13 +782,16 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
|
|||
const param_t param) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
|
||||
constexpr unsigned int MMA_M = 16;
|
||||
constexpr unsigned int MMA_N = 8;
|
||||
constexpr unsigned int MMA_M = 16;
|
||||
constexpr unsigned int MMA_N = 8;
|
||||
|
||||
const uint PQZ = param.Oh * param.Ow * param.Od;
|
||||
|
||||
const unsigned int K = param.c * param.r * param.s;
|
||||
const unsigned int K = param.c * param.r * param.s * param.t;
|
||||
const uint weightKOffset = K; //param.c * param.r * param.s * param.t;
|
||||
const uint inChannelOffset = param.c * param.w;
|
||||
const uint weightKOffset = param.c * param.r * param.s;
|
||||
const uint inDepthOffset = param.h * param.c * param.w;
|
||||
const uint inNOffset = param.c * param.w * param.h * param.d;
|
||||
|
||||
// loop bounds, constexpr where possible allows for loop unrolling
|
||||
constexpr unsigned int mma_tiles_per_warp_k = 4;
|
||||
|
|
@ -863,7 +843,7 @@ constexpr unsigned int MMA_N = 8;
|
|||
|
||||
const half* A_block_gmem = input;
|
||||
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
|
||||
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
|
||||
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
|
||||
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
|
||||
|
||||
int offset_direction = 1;
|
||||
|
|
@ -874,7 +854,8 @@ constexpr unsigned int MMA_N = 8;
|
|||
if (block_k != num_block_tiles_k){
|
||||
const half* A_block_gmem = input;
|
||||
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
|
||||
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
|
||||
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
|
||||
inNOffset, inDepthOffset, inChannelOffset, param);
|
||||
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
|
||||
}
|
||||
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
|
||||
|
|
@ -954,10 +935,13 @@ constexpr unsigned int MMA_N = 8;
|
|||
for (int j = 0; j < 4; ++j){
|
||||
const uint row = m_idx + subk + i * WN / 2;
|
||||
const uint gemm_i = n_idx + j*32;
|
||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
|
||||
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||
// const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
// const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
if(n < param.n && row < param.k && col < PQZ){
|
||||
// const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||
const uint outOffset = (n * param.k + row) * PQZ + col;
|
||||
uint idx = output_lds_addr + subk + j*32*BN/2;
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
output[outOffset] = smemoutput[idx];
|
||||
|
|
@ -974,8 +958,6 @@ constexpr unsigned int MMA_N = 8;
|
|||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#define NUM_VARIANTS 4
|
||||
|
||||
/*
|
||||
|
|
@ -1021,64 +1003,64 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
|
|||
|
||||
static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
|
||||
|
||||
// if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
|
||||
|
||||
// int id = ggml_cuda_get_device();
|
||||
int id = ggml_cuda_get_device();
|
||||
|
||||
// int64_t ne = P.c * P.h * P.w * P.n;
|
||||
// int64_t ne00 = P.c;
|
||||
// int64_t ne01 = P.h * P.w;
|
||||
// ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
|
||||
int64_t ne = P.c * P.h * P.w * P.n;
|
||||
int64_t ne00 = P.c;
|
||||
int64_t ne01 = P.h * P.w;
|
||||
ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
|
||||
|
||||
// dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
// dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
|
||||
// NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
|
||||
dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
|
||||
NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
|
||||
|
||||
// ne = P.c * P.r * P.s * P.k;
|
||||
// ne01 = P.r * P.s;
|
||||
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
|
||||
// dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
// NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
ne = P.c * P.r * P.s * P.k;
|
||||
ne01 = P.r * P.s;
|
||||
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
|
||||
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
|
||||
// const half *X_H = input_f16.get();
|
||||
// const half *K_H = kernel_f16.get();
|
||||
// ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
|
||||
const half *X_H = input_f16.get();
|
||||
const half *K_H = kernel_f16.get();
|
||||
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n);
|
||||
|
||||
// constexpr unsigned int BM_dim = 256;
|
||||
// constexpr unsigned int BN_dim = 256;
|
||||
// constexpr unsigned int BK_dim = 32;
|
||||
constexpr unsigned int BM_dim = 256;
|
||||
constexpr unsigned int BN_dim = 256;
|
||||
constexpr unsigned int BK_dim = 32;
|
||||
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_M = 2;
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_N = 4;
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_K = 4;
|
||||
constexpr unsigned int WARPS_PER_BLOCK_M = 2;
|
||||
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
|
||||
constexpr unsigned int WARPS_PER_BLOCK_K = 4;
|
||||
|
||||
// constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
|
||||
// constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
|
||||
// constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
|
||||
// const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
|
||||
// const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
|
||||
// constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
|
||||
// constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
|
||||
// constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
// const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
|
||||
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
|
||||
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
|
||||
const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
|
||||
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
|
||||
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
|
||||
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
|
||||
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
|
||||
// cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
// cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
|
||||
// dim3 gridDim(BlocksN, BlocksM);
|
||||
// dim3 blockDim(ThreadsN, ThreadsM);
|
||||
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
|
||||
dim3 gridDim(BlocksN, BlocksM);
|
||||
dim3 blockDim(ThreadsN, ThreadsM);
|
||||
|
||||
// conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
|
||||
// WM_dim, WN_dim, WK_dim, NumThreads>
|
||||
// <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
|
||||
// const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
// to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
|
||||
// } else{
|
||||
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
|
||||
WM_dim, WN_dim, WK_dim, NumThreads>
|
||||
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st);
|
||||
} else{
|
||||
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
||||
// }
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,29 @@ typedef struct{
|
|||
} param_t;
|
||||
|
||||
|
||||
template<const int layout>
|
||||
__device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) {
|
||||
|
||||
const unsigned int cur0 = fastdiv(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const unsigned int cur0_res = fastmodulo(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const unsigned int cur1 = fastdiv(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const unsigned int cur1_res = fastmodulo(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const unsigned int cur2 = fastdiv(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int cur3 = fastmodulo(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int curC = layout == 0 ? cur3 : cur0;
|
||||
const unsigned int curT = layout == 0 ? cur0 : cur1;
|
||||
const unsigned int curR = layout == 0 ? cur1 : cur2;
|
||||
const unsigned int curS = layout == 0 ? cur2 : cur3;
|
||||
return make_int4(curC, curT, curR, curS);
|
||||
|
||||
}
|
||||
|
||||
|
||||
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
|
||||
template<unsigned int TILE_ROWS,
|
||||
|
|
@ -67,18 +90,24 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
|||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
||||
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
||||
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
|
||||
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
const unsigned int kidx = thread_col*8;
|
||||
const int4 curIdx = inputIndices<0>(kidx, param);
|
||||
const int curC = curIdx.x;
|
||||
const int curT = curIdx.y;
|
||||
const int curR = curIdx.z;
|
||||
const int curS = curIdx.w;
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
// apply swizzle to the dst index
|
||||
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
|
||||
const unsigned int src_index = thread_row * src_stride + kidx;
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
|
||||
// TODO: move some checks outside of loop?
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
@ -102,6 +131,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
const half* src,
|
||||
half* dst,
|
||||
// const unsigned int src_stride,
|
||||
const unsigned int inNOffset,
|
||||
const unsigned int inDepthOffset,
|
||||
const unsigned int inChannelOffset,
|
||||
param_t param
|
||||
)
|
||||
|
|
@ -129,28 +160,43 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
||||
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
||||
|
||||
const unsigned int kidx = thread_col*8;
|
||||
const int4 curIdx = inputIndices<0>(kidx, param);
|
||||
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
int curH = posh_ori + curR * param.dilation1; // input h
|
||||
int curW = posw_ori + curS * param.dilation0; // input w
|
||||
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
|
||||
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
|
||||
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
// unsigned int inOffset = n * inNOffset;
|
||||
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// int curH = posh_ori + curR * param.dilation1; // input h
|
||||
// int curW = posw_ori + curS * param.dilation0; // input w
|
||||
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
// apply swizzle to the dst index
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
curR < param.r && curS < param.s && curC < param.c){
|
||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
|
||||
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
|
||||
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
// curR < param.r && curS < param.s && curC < param.c){
|
||||
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
|
||||
} else{
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
|
|
@ -174,6 +220,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
|
||||
// const unsigned int src_stride,
|
||||
const unsigned int block_k,
|
||||
const unsigned int inNOffset,
|
||||
const unsigned int inDepthOffset,
|
||||
const unsigned int inChannelOffset,
|
||||
param_t param
|
||||
){
|
||||
|
|
@ -196,23 +244,38 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int kidx = block_k + thread_col*8;
|
||||
const int4 curIdx = inputIndices<0>(kidx, param);
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
int curH = posh_ori + curR * param.dilation1; // input h
|
||||
int curW = posw_ori + curS * param.dilation0; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
curR < param.r && curS < param.s && curC < param.c){
|
||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
|
||||
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
|
||||
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
// unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// int curH = posh_ori + curR * param.dilation1; // input h
|
||||
// int curW = posw_ori + curS * param.dilation0; // input w
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
|
||||
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
|
||||
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
// curR < param.r && curS < param.s && curC < param.c){
|
||||
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
|
||||
} else{
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
|
|
@ -259,14 +322,21 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
|||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
|
||||
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
const unsigned int kidx = block_k + thread_col*8;
|
||||
const int4 curIdx = inputIndices<0>(kidx, param);
|
||||
const int curC = curIdx.x;
|
||||
const int curT = curIdx.y;
|
||||
const int curR = curIdx.z;
|
||||
const int curS = curIdx.w;
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
|
||||
// if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
|
||||
// TODO : move some checks outside of the loop
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
|
|||
|
|
@ -323,13 +323,13 @@ int main(void)
|
|||
// std::make_tuple(960,320,104,152,3,3),
|
||||
// std::make_tuple(1280,1280,26,38,3,3),
|
||||
std::make_tuple(320,1280,26,38,8,3,3,3),
|
||||
std::make_tuple(1280,1280,26,38,8,3,3,3),
|
||||
std::make_tuple(320,1280,52,76,8,3,3,3),
|
||||
std::make_tuple(1280,1280,52,76,8,3,3,3),
|
||||
std::make_tuple(320,1280,104,152,8,3,3,3),
|
||||
std::make_tuple(1280,1280,104,152,8,3,3,3),
|
||||
std::make_tuple(320,1280,208,304,4,3,3,3),
|
||||
std::make_tuple(640,1280,208,304,4,3,3,3),
|
||||
// std::make_tuple(1280,1280,26,38,8,3,3,3),
|
||||
// std::make_tuple(320,1280,52,76,8,3,3,3),
|
||||
// std::make_tuple(1280,1280,52,76,8,3,3,3),
|
||||
// std::make_tuple(320,1280,104,152,8,3,3,3),
|
||||
// std::make_tuple(1280,1280,104,152,8,3,3,3),
|
||||
// std::make_tuple(320,1280,208,304,4,3,3,3),
|
||||
// std::make_tuple(640,1280,208,304,4,3,3,3),
|
||||
// std::make_tuple(1280,1280,26,38,1,1),
|
||||
// std::make_tuple(256,128,768,1024,3,3),
|
||||
// std::make_tuple(128,3,768,1024,3,3),
|
||||
|
|
@ -367,7 +367,7 @@ int main(void)
|
|||
|
||||
|
||||
struct ggml_cgraph * gf_res_0 = NULL;
|
||||
int iterations = 20;
|
||||
int iterations = 0;
|
||||
|
||||
double run_time0;
|
||||
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,
|
||||
|
|
|
|||
Loading…
Reference in New Issue