minor tweak filter tranpose
This commit is contained in:
parent
775e48abb2
commit
3e691046dc
|
|
@ -14,7 +14,7 @@ constexpr uint WARPSIZE = 32;
|
|||
#define CUDA_NCHW_2_NHWC_TILE_DIM 32
|
||||
#define CUDA_NCHW_2_NHWC_BLOCK_NM 8
|
||||
#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
|
||||
#define CUDA_NCHW_2_NHWC_BLOCK_C 64
|
||||
#define CUDA_NCHW_2_NHWC_BLOCK_C 32
|
||||
|
||||
|
||||
//currently not use; in future for split-k kernels
|
||||
|
|
@ -58,12 +58,13 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
|
|||
int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
|
||||
|
||||
__shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM];
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
|
||||
|
||||
const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
|
||||
if(imat >= nmat)
|
||||
break;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
|
||||
if(x < ne01 && y + j < ne00){
|
||||
const int row = threadIdx.y+j;
|
||||
|
|
@ -72,7 +73,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
|
|||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
|
||||
if(ty + j < ne01 && tx < ne00){
|
||||
const int col = (threadIdx.y+j) ^ threadIdx.x;
|
||||
|
|
|
|||
Loading…
Reference in New Issue