increase maximum split factor to 16; use better heuristics to choose split-K factor, reducing tail effect

This commit is contained in:
bssrdf 2025-11-10 11:47:56 -05:00
parent 496c3599c6
commit 1fdcb05dc8
2 changed files with 45 additions and 33 deletions

View File

@ -1,4 +1,5 @@
// #include <cuda_runtime.h>
#include <algorithm>
#include "ggml.h"
#include "common.cuh"
#include "convert.cuh"
@ -951,61 +952,72 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
if (BlocksM * BlocksN < (unsigned int)nsm){
int ks = min(12, nsm / (BlocksM * BlocksN));
int j;
bool can_split = false;
for (j = ks; j >= 2; j--){
if (BlocksM * BlocksN < 2*(unsigned int)nsm){
int j, max_remaining_waves = -1, candidate = -1;
int ks = min(16, nsm / (BlocksM * BlocksN));
if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
ks = 16;
for (j = 2; j <= ks; j++){
const int remainder = (BlocksM * BlocksN * j) % nsm;
if ((P.c * P.r * P.s) % (8*j) == 0){
can_split = true;
break;
if (remainder == 0) {
candidate = j;
max_remaining_waves = 0;
break;
} else if (remainder > max_remaining_waves) {
max_remaining_waves = remainder;
candidate = j;
}
}
}
if(can_split){
if(candidate != -1){
j = candidate;
// printf(" choosing %d, %d \n", j, max_remaining_waves);
if (j == 2) {
const unsigned int ksplit = 2;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 3) {
const unsigned int ksplit = 3;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 3,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 4) {
const unsigned int ksplit = 4;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 4,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 5) {
const unsigned int ksplit = 5;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 5,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 6) {
const unsigned int ksplit = 6;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 6,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 7) {
const unsigned int ksplit = 7;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 7,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 8) {
const unsigned int ksplit = 8;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 8,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 9) {
const unsigned int ksplit = 9;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 9,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 10) {
const unsigned int ksplit = 10;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 10,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 11) {
const unsigned int ksplit = 11;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 11,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if(j == 12) {
const unsigned int ksplit = 12;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
} else if (j == 12) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 12,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 13) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 13,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 14) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 14,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 15) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 15,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 16) {
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 16,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
}
return;

View File

@ -653,7 +653,7 @@ int main(void)
int k = 0;
// for (auto c : configs_sdxl_1024){
// for (auto c : configs_sdxl_768){
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),