restore split-k for small inputs

This commit is contained in:
bssrdf 2025-11-15 23:59:38 -05:00
parent 3591e83db9
commit 721fa41076
2 changed files with 42 additions and 43 deletions

View File

@ -83,47 +83,48 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
}
//*** broken, has bugs ***
// template <typename src_T, typename dst_T, const unsigned int mask, const int rs, const unsigned int blk_c>
// static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
template <typename src_T, typename dst_T, const unsigned int mask, const int rs, const unsigned int blk_c>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
// const int64_t nmat = ne / (ne00 * ne01);
// const int64_t n = ne00 * ne01;
const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;
// const unsigned int tx = threadIdx.x;
// const unsigned int bx = blockIdx.x;
// const unsigned int by = blockIdx.y;
const unsigned int tx = threadIdx.x;
const unsigned int bx = blockIdx.x;
const unsigned int by = blockIdx.y;
// __shared__ src_T tile[rs*blk_c];
__shared__ src_T tile[rs*blk_c];
// #pragma unroll
// for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
#pragma unroll
for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){
// const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
// if(imat >= nmat)
// break;
// #pragma unroll
// for (unsigned int j = 0; j < rs; j++){
// const unsigned int row = (j * blk_c + tx) % rs;
// const unsigned int col = (j * blk_c + tx) / rs;
// const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx;
// unsigned int idx = row * blk_c + col;
// idx = idx ^ ((idx & mask) >> 4);
// if (src_index < ne) {
// tile[idx] = src[src_index];
// }
// }
// __syncthreads();
// #pragma unroll
// for (unsigned int j = 0; j < rs; j++){
// const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
// if(dst_index < ne){
// unsigned int idx = j*blk_c + tx;
// idx = idx ^ ((idx & mask) >> 4);
// dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
// }
// }
// }
// }
const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
if(imat >= nmat)
break;
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int row = (j * blk_c + tx) % rs;
const unsigned int col = (j * blk_c + tx) / rs;
const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx;
// const unsigned int src_index = imat*n + rs*ne00 + bx * blk_c + j * blk_c + tx;
unsigned int idx = row * blk_c + col;
// idx = idx ^ ((idx & mask) >> 4);
if (src_index < ne) {
tile[idx] = src[src_index];
}
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < rs; j++){
const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
if(dst_index < ne){
unsigned int idx = j*blk_c + tx;
// idx = idx ^ ((idx & mask) >> 4);
dst[dst_index] = ggml_cuda_cast<dst_T>(tile[idx]);
}
}
}
}
@ -1338,9 +1339,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id));
if (ne01 > 1){
kernel_f16.alloc(ne);
dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
1) ;
// dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
// 1) ;
// if (ne01 == 25) {
// constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C);
// NCHW2NHWC<half, half, mask, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
@ -1424,10 +1425,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
}
}
}
candidate = -1;
if(candidate != -1){
j = candidate;
printf("choosing %d \n", j);
if (j == 2) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);

View File

@ -319,7 +319,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(1920,640,32,32,3,3)
// std::make_tuple(1280,1280,16,16,3,3),
std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(32,12,141,133,3,3),
// std::make_tuple(32,6,141,133,3,3),
// std::make_tuple(32,12,141,121,3,3),
@ -330,7 +330,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(256,12,16,16,3,3), //working
// std::make_tuple(32,12,16,16,3,3), //not working
// std::make_tuple(48,12,16,16,3,3), // not working
std::make_tuple(96,12,16,16,3,3), //not working
// std::make_tuple(96,12,16,16,3,3), //not working
// std::make_tuple(64,12,16,16,3,3), //working
// std::make_tuple(64,12,141,133,3,3), //working
// std::make_tuple(32,12,141,133,3,3), //working