diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index da2a6868a9..ad3f50c85e 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -10,6 +10,7 @@ constexpr uint WARPSIZE = 32;
 #define CUDA_NCHW_2_NHWC_TILE_DIM 32
 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8
 #define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8
+#define CUDA_NCHW_2_NHWC_BLOCK_C 64 // channel tile width; intended later for split-k kernels as well
@@ -64,6 +65,45 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
     }
 }
 
+template <typename src_T, typename dst_T, int rs, int blk_c>
+static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01) {
+
+    const int64_t nmat = ne / (ne00 * ne01);
+    const int64_t n    = ne00 * ne01;
+
+    const unsigned int tx = threadIdx.x;
+    const unsigned int bx = blockIdx.x;
+    const unsigned int by = blockIdx.y;
+
+    __shared__ src_T tile[rs*blk_c];
+
+    for (int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i) {
+
+        const unsigned int imat = by * CUDA_NCHW_2_NHWC_BLOCK_NM + i;
+        if (imat >= nmat)
+            break;
+        // load one blk_c-channel x rs tile, transposing it into shared memory
+        for (unsigned int j = 0; j < rs; j++) {
+            const unsigned int row = (j * blk_c + tx) % rs;
+            const unsigned int col = (j * blk_c + tx) / rs;
+            const unsigned int src_index = imat*n + bx * blk_c * rs + j * blk_c + tx;
+            if (src_index < ne) {
+                tile[row * blk_c + col] = src[src_index];
+            }
+        }
+        __syncthreads();
+
+        // write the transposed tile back with coalesced stores
+        for (unsigned int j = 0; j < rs; j++) {
+            const unsigned int dst_index = imat*n + j*ne00 + bx*blk_c + tx;
+            if (dst_index < ne) {
+                dst[dst_index] = ggml_cuda_cast<dst_T>(tile[j*blk_c+tx]);
+            }
+        }
+        __syncthreads(); // all reads of tile must finish before the next iteration overwrites it
+    }
+}
+
 template <typename src_T, typename dst_T>
@@ … @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
-    if (GGML_CUDA_CC_IS_NVIDIA(cc) && … && (P.r > 1 || P.s > 1)) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
     int id = ggml_cuda_get_device();
@@ -826,13 +867,40 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         ne = P.c * P.r * P.s * P.k;
         ne01 = P.r * P.s;
         ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
-        dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
-                      (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
-                      (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM);
-        NCHW2NHWC<half, half><<<dimGrid1, dim3(CUDA_NCHW_2_NHWC_TILE_DIM, CUDA_NCHW_2_NHWC_BLOCK_ROWS), 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+        if (ne01 > 1) {
+            dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C,
+                          (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM,
+                          1);
+            if (ne01 == 25) {
+                NCHW2NHWC<half, half, 25, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 16) {
+                NCHW2NHWC<half, half, 16, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 9) {
+                NCHW2NHWC<half, half, 9, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 8) {
+                NCHW2NHWC<half, half, 8, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 7) {
+                NCHW2NHWC<half, half, 7, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 6) {
+                NCHW2NHWC<half, half, 6, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 5) {
+                NCHW2NHWC<half, half, 5, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 4) {
+                NCHW2NHWC<half, half, 4, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 3) {
+                NCHW2NHWC<half, half, 3, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else if (ne01 == 2) {
+                NCHW2NHWC<half, half, 2, CUDA_NCHW_2_NHWC_BLOCK_C><<<dimGrid1, CUDA_NCHW_2_NHWC_BLOCK_C, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            } else {
+                dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                              (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
+                              (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM);
+                NCHW2NHWC<half, half><<<dimGrid2, dim3(CUDA_NCHW_2_NHWC_TILE_DIM, CUDA_NCHW_2_NHWC_BLOCK_ROWS), 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
+            }
+        }
         const half *X_H = input_f16.get();
-        const half *K_H = kernel_f16.get();
+        const half *K_H = ne01 == 1 ? K_D : kernel_f16.get();
 
         constexpr unsigned int BM_dim = 256;
         constexpr unsigned int BN_dim = 256;
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index db43cf1847..844cce8923 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -299,9 +299,9 @@ static std::vector<std::tuple<int,int,int,int,int,int>> configs = {
     // std::make_tuple(640,640,52,76,3,3),
     // std::make_tuple(640,640,104,152,3,3),
     // std::make_tuple(960,320,104,152,3,3),
-    // std::make_tuple(1280,1280,26,38,3,3),
+    std::make_tuple(1280,1280,26,38,3,3),
     // std::make_tuple(1920,640,32,32,3,3)
-    std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(1280,1280,16,16,3,3),
     // std::make_tuple(320,640,32,32,3,3),
     // std::make_tuple(4,320,96,128,3,3),
     // std::make_tuple(320,4,96,128,3,3),
@@ -653,8 +653,8 @@ int main(void)
     int k = 0;
 
-    // for (auto c : configs_sdxl_1024){
-    for (auto c : configs){
+    for (auto c : configs_sdxl_1024){
+    // for (auto c : configs){
         test_model model;
         load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true);
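
The new NCHW2NHWC overload transposes each filter of the NCHW kernel tensor ([k][c][r*s]) into NHWC order ([k][r*s][c]) through an rs x blk_c shared-memory tile, with the window size rs baked in as a template parameter; when ne01 == 1 (a 1x1 window) the two layouts coincide, which is why the dispatch now feeds K_D straight through as K_H and skips the transform. The angle-bracket contents of the diff (template arguments, launch configurations) did not survive extraction, so they are reconstructed above from the kernel's indexing; the standalone sketch below spells that indexing out and verifies it. It assumes plain float data (the real kernel is templated over src_T/dst_T and converts via ggml_cuda_cast), ne00 divisible by blk_c, and my own names nchw2nhwc/BLOCK_NM. It compiles with nvcc as-is:

#include <cstdio>
#include <cuda_runtime.h>

constexpr int BLOCK_NM = 8; // filters processed per block along grid y

// Per-filter [C][RS] -> [RS][C] transpose through an rs x blk_c shared tile.
template <int rs, int blk_c>
__global__ void nchw2nhwc(const float * src, float * dst, const int ne, const int ne00, const int ne01) {
    const int nmat = ne / (ne00 * ne01); // number of filters k
    const int n    = ne00 * ne01;        // elements per filter: C * RS

    const int tx = threadIdx.x;          // channel lane inside the tile
    const int bx = blockIdx.x;           // which blk_c-wide channel tile
    const int by = blockIdx.y;           // which group of BLOCK_NM filters

    __shared__ float tile[rs * blk_c];

    for (int i = 0; i < BLOCK_NM; ++i) {
        const int imat = by * BLOCK_NM + i;
        if (imat >= nmat) {
            break;
        }
        // read blk_c*rs consecutive NCHW elements; element idx belongs to
        // channel idx/rs of the tile and window position idx%rs
        for (int j = 0; j < rs; ++j) {
            const int idx       = j * blk_c + tx;
            const int src_index = imat * n + bx * blk_c * rs + idx;
            if (src_index < ne) {
                tile[(idx % rs) * blk_c + idx / rs] = src[src_index];
            }
        }
        __syncthreads();

        // coalesced stores: consecutive threads write consecutive channels
        for (int j = 0; j < rs; ++j) {
            const int dst_index = imat * n + j * ne00 + bx * blk_c + tx;
            if (dst_index < ne) {
                dst[dst_index] = tile[j * blk_c + tx];
            }
        }
        __syncthreads(); // tile is reused by the next filter group
    }
}

int main() {
    constexpr int rs = 9, blk_c = 64; // e.g. a 3x3 window, 64-channel tiles
    const int C = 128, K = 16, ne = K * C * rs;

    float * src, * dst;
    cudaMallocManaged(&src, ne * sizeof(float));
    cudaMallocManaged(&dst, ne * sizeof(float));
    for (int i = 0; i < ne; ++i) {
        src[i] = (float) i;
    }

    const dim3 grid((C + blk_c - 1) / blk_c, (K + BLOCK_NM - 1) / BLOCK_NM, 1);
    nchw2nhwc<rs, blk_c><<<grid, blk_c>>>(src, dst, ne, C, rs);
    cudaDeviceSynchronize();

    int mismatches = 0;
    for (int k = 0; k < K; ++k) {
        for (int j = 0; j < rs; ++j) {
            for (int c = 0; c < C; ++c) { // NHWC element (k, j, c) == NCHW element (k, c, j)
                mismatches += dst[k * C * rs + j * C + c] != src[k * C * rs + c * rs + j];
            }
        }
    }
    printf("%s\n", mismatches == 0 ? "transpose OK" : "MISMATCH");
    cudaFree(src); cudaFree(dst);
    return mismatches != 0;
}

The trailing __syncthreads() matters: without it, a fast thread can begin overwriting tile for filter group i+1 while a slower thread is still reading filter group i's values.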
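The dispatch in the patch turns the runtime r*s into a compile-time rs through an if/else ladder over the window sizes that occur in practice (25, 16, 9, ... down to 2). A switch expresses the same runtime-to-compile-time dispatch more compactly; this sketch builds on the nchw2nhwc kernel above, and launch_rs/dispatch_rs are hypothetical names of mine, not API from the patch:

// Instantiate and launch the transpose for a compile-time rs.
template <int rs>
static void launch_rs(const float * src, float * dst, int ne, int ne00, dim3 grid, cudaStream_t st) {
    nchw2nhwc<rs, 64><<<grid, 64, 0, st>>>(src, dst, ne, ne00, rs);
}

// Returns false when rs has no specialization, so the caller can fall back
// to the generic 32x32-tile transpose (the patch's else branch).
static bool dispatch_rs(int rs, const float * src, float * dst, int ne, int ne00, dim3 grid, cudaStream_t st) {
    switch (rs) {
        case 25: launch_rs<25>(src, dst, ne, ne00, grid, st); return true;
        case 16: launch_rs<16>(src, dst, ne, ne00, grid, st); return true;
        case  9: launch_rs< 9>(src, dst, ne, ne00, grid, st); return true;
        case  8: launch_rs< 8>(src, dst, ne, ne00, grid, st); return true;
        case  7: launch_rs< 7>(src, dst, ne, ne00, grid, st); return true;
        case  6: launch_rs< 6>(src, dst, ne, ne00, grid, st); return true;
        case  5: launch_rs< 5>(src, dst, ne, ne00, grid, st); return true;
        case  4: launch_rs< 4>(src, dst, ne, ne00, grid, st); return true;
        case  3: launch_rs< 3>(src, dst, ne, ne00, grid, st); return true;
        case  2: launch_rs< 2>(src, dst, ne, ne00, grid, st); return true;
        default: return false;
    }
}

Either form pays the same compile-time cost (one kernel instantiation per supported rs) in exchange for fully unrolled tile loops and a statically sized shared-memory array.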