conv3d WIP: turn on tensor cores; NCDHW2NDHWC to be worked out

This commit is contained in:
bssrdf 2025-11-02 15:15:49 -05:00
parent a5b68bcea7
commit 3f5c5045da
3 changed files with 182 additions and 130 deletions

View File

@ -62,28 +62,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
}
template<const int layout>
__device__ int4 inputIndices(const uint kidx, param_t param) {
const uint cur0 = fastdiv(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const uint cur0_res = fastmodulo(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const uint cur1 = fastdiv(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const uint cur1_res = fastmodulo(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const uint cur2 = fastdiv(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint cur3 = fastmodulo(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const uint curC = layout == 0 ? cur3 : cur0;
const uint curT = layout == 0 ? cur0 : cur1;
const uint curR = layout == 0 ? cur1 : cur2;
const uint curS = layout == 0 ? cur2 : cur3;
return make_int4(curC, curT, curR, curS);
}
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
@ -553,7 +531,6 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
}
}
#if 0
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
__device__ __forceinline__ void ldmatrix_a(
@ -805,13 +782,16 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const param_t param) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
const uint PQZ = param.Oh * param.Ow * param.Od;
const unsigned int K = param.c * param.r * param.s;
const unsigned int K = param.c * param.r * param.s * param.t;
const uint weightKOffset = K; //param.c * param.r * param.s * param.t;
const uint inChannelOffset = param.c * param.w;
const uint weightKOffset = param.c * param.r * param.s;
const uint inDepthOffset = param.h * param.c * param.w;
const uint inNOffset = param.c * param.w * param.h * param.d;
// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
@ -863,7 +843,7 @@ constexpr unsigned int MMA_N = 8;
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
int offset_direction = 1;
@ -874,7 +854,8 @@ constexpr unsigned int MMA_N = 8;
if (block_k != num_block_tiles_k){
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
@ -954,10 +935,13 @@ constexpr unsigned int MMA_N = 8;
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if(n < param.n && row < param.k && col < param.Oh * param.Ow){
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
// const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
const int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const int col = fastmodulo(gemm_i, param.PQZ_fastdiv);
if(n < param.n && row < param.k && col < PQZ){
// const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
const uint outOffset = (n * param.k + row) * PQZ + col;
uint idx = output_lds_addr + subk + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
output[outOffset] = smemoutput[idx];
@ -974,8 +958,6 @@ constexpr unsigned int MMA_N = 8;
#endif
}
#endif
#define NUM_VARIANTS 4
/*
@ -1021,64 +1003,64 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
// if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
// int id = ggml_cuda_get_device();
int id = ggml_cuda_get_device();
// int64_t ne = P.c * P.h * P.w * P.n;
// int64_t ne00 = P.c;
// int64_t ne01 = P.h * P.w;
// ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
int64_t ne = P.c * P.h * P.w * P.n;
int64_t ne00 = P.c;
int64_t ne01 = P.h * P.w;
ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
// dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
// dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
// NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
// ne = P.c * P.r * P.s * P.k;
// ne01 = P.r * P.s;
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
// dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
// NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
ne = P.c * P.r * P.s * P.k;
ne01 = P.r * P.s;
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
// const half *X_H = input_f16.get();
// const half *K_H = kernel_f16.get();
// ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
const half *X_H = input_f16.get();
const half *K_H = kernel_f16.get();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Od *P.Oh * P.Ow * P.n);
// constexpr unsigned int BM_dim = 256;
// constexpr unsigned int BN_dim = 256;
// constexpr unsigned int BK_dim = 32;
constexpr unsigned int BM_dim = 256;
constexpr unsigned int BN_dim = 256;
constexpr unsigned int BK_dim = 32;
// constexpr unsigned int WARPS_PER_BLOCK_M = 2;
// constexpr unsigned int WARPS_PER_BLOCK_N = 4;
// constexpr unsigned int WARPS_PER_BLOCK_K = 4;
constexpr unsigned int WARPS_PER_BLOCK_M = 2;
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
constexpr unsigned int WARPS_PER_BLOCK_K = 4;
// constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
// constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
// constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
// const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
// const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
// constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
// constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
// constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
// const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
// cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
// cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
// dim3 gridDim(BlocksN, BlocksM);
// dim3 blockDim(ThreadsN, ThreadsM);
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
// conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
// WM_dim, WN_dim, WK_dim, NumThreads>
// <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
// const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
// to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
// } else{
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow *P.Od * P.n, st);
} else{
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
// }
}
}

View File

@ -35,6 +35,29 @@ typedef struct{
} param_t;
template<const int layout>
__device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) {
const unsigned int cur0 = fastdiv(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const unsigned int cur0_res = fastmodulo(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
const unsigned int cur1 = fastdiv(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const unsigned int cur1_res = fastmodulo(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
const unsigned int cur2 = fastdiv(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int cur3 = fastmodulo(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
const unsigned int curC = layout == 0 ? cur3 : cur0;
const unsigned int curT = layout == 0 ? cur0 : cur1;
const unsigned int curR = layout == 0 ? cur1 : cur2;
const unsigned int curS = layout == 0 ? cur2 : cur3;
return make_int4(curC, curT, curR, curS);
}
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
@ -67,18 +90,24 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int kidx = thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
const int curR = curIdx.z;
const int curS = curIdx.w;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
const unsigned int src_index = thread_row * src_stride + kidx;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
// TODO: move some checks outside of loop?
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -102,6 +131,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
const half* src,
half* dst,
// const unsigned int src_stride,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
param_t param
)
@ -129,28 +160,43 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int kidx = thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
int curH = posh_ori + curR * param.dilation1; // input h
int curW = posw_ori + curS * param.dilation0; // input w
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int inOffset = n * inNOffset;
// const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// int curH = posh_ori + curR * param.dilation1; // input h
// int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -174,6 +220,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
param_t param
){
@ -196,23 +244,38 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int kidx = block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
int curH = posh_ori + curR * param.dilation1; // input h
int curW = posw_ori + curS * param.dilation0; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
// unsigned int inOffset = n * param.c * param.h * param.w;
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// int curH = posh_ori + curR * param.dilation1; // input h
// int curW = posw_ori + curS * param.dilation0; // input w
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -259,14 +322,21 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
// const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int kidx = block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
const int curR = curIdx.z;
const int curS = curIdx.w;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
// if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
// TODO : move some checks outside of the loop
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);

View File

@ -323,13 +323,13 @@ int main(void)
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(320,1280,26,38,8,3,3,3),
std::make_tuple(1280,1280,26,38,8,3,3,3),
std::make_tuple(320,1280,52,76,8,3,3,3),
std::make_tuple(1280,1280,52,76,8,3,3,3),
std::make_tuple(320,1280,104,152,8,3,3,3),
std::make_tuple(1280,1280,104,152,8,3,3,3),
std::make_tuple(320,1280,208,304,4,3,3,3),
std::make_tuple(640,1280,208,304,4,3,3,3),
// std::make_tuple(1280,1280,26,38,8,3,3,3),
// std::make_tuple(320,1280,52,76,8,3,3,3),
// std::make_tuple(1280,1280,52,76,8,3,3,3),
// std::make_tuple(320,1280,104,152,8,3,3,3),
// std::make_tuple(1280,1280,104,152,8,3,3,3),
// std::make_tuple(320,1280,208,304,4,3,3,3),
// std::make_tuple(640,1280,208,304,4,3,3,3),
// std::make_tuple(1280,1280,26,38,1,1),
// std::make_tuple(256,128,768,1024,3,3),
// std::make_tuple(128,3,768,1024,3,3),
@ -367,7 +367,7 @@ int main(void)
struct ggml_cgraph * gf_res_0 = NULL;
int iterations = 20;
int iterations = 0;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,