From a1fb3c150941cd1b5fedea79aa2053c348415aff Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Sat, 8 Nov 2025 16:45:59 -0500
Subject: [PATCH] fixed a bug: now split-k can choose a better split factor

---
 ggml/src/ggml-cuda/conv2d-implicit.cu  | 22 +++++++++++++++++-----
 ggml/src/ggml-cuda/conv2d-implicit.cuh |  4 ++--
 tests/test-conv2d.cpp                  |  8 +++++---
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 1133626d14..1bf94476ab 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -856,13 +856,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
     const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-    // const unsigned int ksplit = 6;
     // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
-    printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN));
-    if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){
+    if (BlocksM * BlocksN < nsm){
         int ks = nsm / (BlocksM * BlocksN);
-        printf("split factor init = %d \n", ks);
         int j;
         bool can_split = false;
         for (j = ks; j >= 2; j--){
@@ -872,7 +869,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
             }
         }
         if(can_split){
-            printf("split factor = %d \n", j);
             if (j == 2) {
                 const unsigned int ksplit = 2;
                 launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 9) {
+                const unsigned int ksplit = 9;
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 10) {
+                const unsigned int ksplit = 10;
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else if (j == 11) {
+                const unsigned int ksplit = 11;
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+            } else {
+                const unsigned int ksplit = 12;
+                launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
             }
             return;
         }
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 85936e42c6..35764c5b63 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -72,7 +72,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             dst_float4[dst_index] = reinterpret_cast<float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -273,7 +273,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         const unsigned int src_index = thread_row * src_stride + ki;
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             dst_reg[i] = reinterpret_cast<float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 5af7da0a91..c460ca7d87 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -300,7 +300,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
     // std::make_tuple(640,640,104,152,3,3),
     // std::make_tuple(960,320,104,152,3,3),
     // std::make_tuple(1280,1280,26,38,3,3),
-    std::make_tuple(1920,640,32,32,3,3)
+    // std::make_tuple(1920,640,32,32,3,3)
+    std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(320,640,32,32,3,3),
     // std::make_tuple(4,320,96,128,3,3),
     // std::make_tuple(320,4,96,128,3,3),
     // std::make_tuple(4,320,64,96,3,3),
@@ -651,8 +653,8 @@ int main(void)
     int k = 0;
-    // for (auto c : configs_sdxl_512){
-    for (auto c : configs){
+    for (auto c : configs_sdxl_768){
+    // for (auto c : configs){
         test_model model;
         load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), std::get<4>(c), std::get<5>(c), true);
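
Note on the split-factor selection this patch relaxes: the old guard capped the factor at
nsm / (BlocksM * BlocksN) <= 8, while the new dispatch accepts factors up to 12. The body of
the selection loop sits outside the hunk, so the sketch below assumes the divisibility test
matches the commented-out condition removed above; ConvParam, pick_split_factor and the
device/grid numbers in main are hypothetical stand-ins, not names from the repository.

    #include <cstdio>

    struct ConvParam { int c, r, s; };  // input channels, filter height, filter width

    // Pick the largest split factor j in [2, nsm / (BlocksM * BlocksN)] such that the
    // reduction dimension c*r*s divides evenly into j chunks of 8-wide loads.
    // Returns 0 when no usable factor exists (fall back to the non-split kernel).
    static int pick_split_factor(int nsm, int BlocksM, int BlocksN, const ConvParam & P) {
        if (BlocksM * BlocksN >= nsm) {
            return 0;  // the output block grid already fills the SMs; no need to split K
        }
        const int ks = nsm / (BlocksM * BlocksN);
        for (int j = ks; j >= 2; j--) {
            // assumed test, taken from the removed comment:
            // P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0
            if (P.c >= 8 * j && (P.c * P.r * P.s) % (8 * j) == 0) {
                return j;
            }
        }
        return 0;
    }

    int main() {
        // e.g. the 1280x1280, 16x16, 3x3 case enabled in tests/test-conv2d.cpp above,
        // on a hypothetical 132-SM device with a 2x5 thread-block grid
        const ConvParam P = {1280, 3, 3};
        printf("split factor = %d\n", pick_split_factor(132, 2, 5, P));
        return 0;
    }

With the old "<= 8" guard this case would have skipped split-k entirely; dropping the cap
lets the loop settle on a larger factor that the new j == 9..12 branches can dispatch.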