tuned block dimensions for filter tranpose

This commit is contained in:
bssrdf 2025-11-17 11:45:01 -05:00
parent 3e691046dc
commit 9bb5eb30e5
1 changed files with 19 additions and 27 deletions

View File

@ -14,7 +14,7 @@ constexpr uint WARPSIZE = 32;
#define CUDA_NCHW_2_NHWC_TILE_DIM 32
#define CUDA_NCHW_2_NHWC_BLOCK_NM 8
#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
#define CUDA_NCHW_2_NHWC_BLOCK_C 32
#define CUDA_NCHW_2_NHWC_BLOCK_C 64
//currently not use; in future for split-k kernels
@ -86,7 +86,6 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
template <typename src_T, typename dst_T, const unsigned int mask, const int rs, const unsigned int blk_c>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;
const unsigned int tx = threadIdx.x;
@ -97,32 +96,26 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
__shared__ src_T tile[rs*blk_c];
#pragma unroll
for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
if(imat >= nmat)
break;
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int row = (j * blk + tx) % rs;
const unsigned int col = (j * blk + tx) / rs;
const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk + tx;
unsigned int idx = row * blk_c + col;
idx = idx ^ ((idx & mask) >> 4);
if (src_index < ne && tx < blk) {
tile[idx] = src[src_index];
}
for (unsigned int j = 0; j < rs; j++){
const unsigned int row = (j * blk + tx) % rs;
const unsigned int col = (j * blk + tx) / rs;
const unsigned int src_index = by*n + bx * blk_c * rs + j * blk + tx;
unsigned int idx = row * blk_c + col;
idx = idx ^ ((idx & mask) >> 4);
if (src_index < ne && tx < blk) {
tile[idx] = src[src_index];
}
__syncthreads();
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
if(dst_index < ne && tx < blk){
unsigned int idx = j*blk_c + tx;
idx = idx ^ ((idx & mask) >> 4);
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
}
for (unsigned int j = 0; j < rs; j++){
const unsigned int dst_index = by*n + j*ne00 + bx*blk_c + tx;
if(dst_index < ne && tx < blk){
unsigned int idx = j*blk_c + tx;
idx = idx ^ ((idx & mask) >> 4);
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
}
}
}
@ -1222,14 +1215,13 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
ne = P.c * P.r * P.s * P.k;
ne01 = P.r * P.s;
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
if (ne01 > 1){
kernel_f16.alloc(ne);
dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
1) ;
ne/(ne00*ne01),
1) ;
if (ne01 == 25) {
constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C);
NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);