From 1fdcb05dc8740864e47468fe4946d8bfea823b06 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Mon, 10 Nov 2025 11:47:56 -0500 Subject: [PATCH] increase maximum split factor to 16; use better heuristics to choose split-K factor, reducing tail effect --- ggml/src/ggml-cuda/conv2d-implicit.cu | 76 ++++++++++++++++----------- tests/test-conv2d.cpp | 2 +- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 99fa1925d5..09f317e2d3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,4 +1,5 @@ // #include +#include #include "ggml.h" #include "common.cuh" #include "convert.cuh" @@ -951,61 +952,72 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { - if (BlocksM * BlocksN < (unsigned int)nsm){ - - int ks = min(12, nsm / (BlocksM * BlocksN)); - int j; - bool can_split = false; - for (j = ks; j >= 2; j--){ + if (BlocksM * BlocksN < 2*(unsigned int)nsm){ + int j, max_remaining_waves = -1, candidate = -1; + int ks = min(16, nsm / (BlocksM * BlocksN)); + if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5) + ks = 16; + for (j = 2; j <= ks; j++){ + const int remainder = (BlocksM * BlocksN * j) % nsm; if ((P.c * P.r * P.s) % (8*j) == 0){ - can_split = true; - break; + if (remainder == 0) { + candidate = j; + max_remaining_waves = 0; + break; + } else if (remainder > max_remaining_waves) { + max_remaining_waves = remainder; + candidate = j; + } } } - if(can_split){ + + if(candidate != -1){ + j = candidate; + // printf(" choosing %d, %d \n", j, max_remaining_waves); if (j == 2) { - const unsigned int ksplit = 2; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 3) { - const unsigned int ksplit = 3; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 4) { - const unsigned int ksplit = 4; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 5) { - const unsigned int ksplit = 5; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 6) { - const unsigned int ksplit = 6; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 7) { - const unsigned int ksplit = 7; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 8) { - const unsigned int ksplit = 8; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 9) { - const unsigned int ksplit = 9; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 10) { - const unsigned int ksplit = 10; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } else if (j == 11) { - const unsigned int ksplit = 11; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); - } else if(j == 12) { - const unsigned int ksplit = 12; - launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 13) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 14) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 15) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 16) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); } return; diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index b3d5e8c724..a5f2a6aef6 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -653,7 +653,7 @@ int main(void) int k = 0; - // for (auto c : configs_sdxl_1024){ + // for (auto c : configs_sdxl_768){ for (auto c : configs){ test_model model; load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),