increase maximum split factor to 16; use better heuristics to choose split-K factor, reducing tail effect

This commit is contained in:
bssrdf 2025-11-10 11:47:56 -05:00
parent 496c3599c6
commit 1fdcb05dc8
2 changed files with 45 additions and 33 deletions

View File

@ -1,4 +1,5 @@
// #include <cuda_runtime.h> // #include <cuda_runtime.h>
#include <algorithm>
#include "ggml.h" #include "ggml.h"
#include "common.cuh" #include "common.cuh"
#include "convert.cuh" #include "convert.cuh"
@ -951,61 +952,72 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
if (BlocksM * BlocksN < (unsigned int)nsm){ if (BlocksM * BlocksN < 2*(unsigned int)nsm){
int j, max_remaining_waves = -1, candidate = -1;
int ks = min(12, nsm / (BlocksM * BlocksN)); int ks = min(16, nsm / (BlocksM * BlocksN));
int j; if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
bool can_split = false; ks = 16;
for (j = ks; j >= 2; j--){ for (j = 2; j <= ks; j++){
const int remainder = (BlocksM * BlocksN * j) % nsm;
if ((P.c * P.r * P.s) % (8*j) == 0){ if ((P.c * P.r * P.s) % (8*j) == 0){
can_split = true; if (remainder == 0) {
break; candidate = j;
max_remaining_waves = 0;
break;
} else if (remainder > max_remaining_waves) {
max_remaining_waves = remainder;
candidate = j;
}
} }
} }
if(can_split){
if(candidate != -1){
j = candidate;
// printf(" choosing %d, %d \n", j, max_remaining_waves);
if (j == 2) { if (j == 2) {
const unsigned int ksplit = 2; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 3) { } else if (j == 3) {
const unsigned int ksplit = 3; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 3,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 4) { } else if (j == 4) {
const unsigned int ksplit = 4; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 4,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 5) { } else if (j == 5) {
const unsigned int ksplit = 5; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 5,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 6) { } else if (j == 6) {
const unsigned int ksplit = 6; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 6,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 7) { } else if (j == 7) {
const unsigned int ksplit = 7; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 7,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 8) { } else if (j == 8) {
const unsigned int ksplit = 8; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 8,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 9) { } else if (j == 9) {
const unsigned int ksplit = 9; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 9,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 10) { } else if (j == 10) {
const unsigned int ksplit = 10; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 10,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 11) { } else if (j == 11) {
const unsigned int ksplit = 11; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 11,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if(j == 12) { } else if (j == 12) {
const unsigned int ksplit = 12; launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 12,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 13) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 13,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 14) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 14,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 15) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 15,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 16) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 16,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} }
return; return;

View File

@ -653,7 +653,7 @@ int main(void)
int k = 0; int k = 0;
// for (auto c : configs_sdxl_1024){ // for (auto c : configs_sdxl_768){
for (auto c : configs){ for (auto c : configs){
test_model model; test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),