minor tweak filter tranpose

This commit is contained in:
bssrdf 2025-11-17 10:41:30 -05:00
parent 775e48abb2
commit 3e691046dc
1 changed files with 4 additions and 3 deletions

View File

@ -14,7 +14,7 @@ constexpr uint WARPSIZE = 32;
#define CUDA_NCHW_2_NHWC_TILE_DIM 32
#define CUDA_NCHW_2_NHWC_BLOCK_NM 8
#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
#define CUDA_NCHW_2_NHWC_BLOCK_C 64
#define CUDA_NCHW_2_NHWC_BLOCK_C 32
//currently not use; in future for split-k kernels
@ -58,12 +58,13 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y;
__shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM];
#pragma unroll
for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
if(imat >= nmat)
break;
#pragma unroll
for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
if(x < ne01 && y + j < ne00){
const int row = threadIdx.y+j;
@ -72,7 +73,7 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){
if(ty + j < ne01 && tx < ne00){
const int col = (threadIdx.y+j) ^ threadIdx.x;