Increase the maximum split-K factor to 16 and use a better heuristic to choose the split-K factor, reducing the tail effect.
This commit is contained in:
parent
496c3599c6
commit
1fdcb05dc8
|
|
@ -1,4 +1,5 @@
|
|||
// #include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
#include "ggml.h"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
|
|
@ -951,61 +952,72 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
|
||||
if (BlocksM * BlocksN < (unsigned int)nsm){
|
||||
|
||||
int ks = min(12, nsm / (BlocksM * BlocksN));
|
||||
int j;
|
||||
bool can_split = false;
|
||||
for (j = ks; j >= 2; j--){
|
||||
if (BlocksM * BlocksN < 2*(unsigned int)nsm){
|
||||
int j, max_remaining_waves = -1, candidate = -1;
|
||||
int ks = min(16, nsm / (BlocksM * BlocksN));
|
||||
if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
|
||||
ks = 16;
|
||||
for (j = 2; j <= ks; j++){
|
||||
const int remainder = (BlocksM * BlocksN * j) % nsm;
|
||||
if ((P.c * P.r * P.s) % (8*j) == 0){
|
||||
can_split = true;
|
||||
break;
|
||||
if (remainder == 0) {
|
||||
candidate = j;
|
||||
max_remaining_waves = 0;
|
||||
break;
|
||||
} else if (remainder > max_remaining_waves) {
|
||||
max_remaining_waves = remainder;
|
||||
candidate = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(can_split){
|
||||
|
||||
if(candidate != -1){
|
||||
j = candidate;
|
||||
// printf(" choosing %d, %d \n", j, max_remaining_waves);
|
||||
if (j == 2) {
|
||||
const unsigned int ksplit = 2;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 3) {
|
||||
const unsigned int ksplit = 3;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 3,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 4) {
|
||||
const unsigned int ksplit = 4;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 4,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 5) {
|
||||
const unsigned int ksplit = 5;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 5,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 6) {
|
||||
const unsigned int ksplit = 6;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 6,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 7) {
|
||||
const unsigned int ksplit = 7;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 7,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 8) {
|
||||
const unsigned int ksplit = 8;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 8,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 9) {
|
||||
const unsigned int ksplit = 9;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 9,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 10) {
|
||||
const unsigned int ksplit = 10;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 10,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 11) {
|
||||
const unsigned int ksplit = 11;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 11,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if(j == 12) {
|
||||
const unsigned int ksplit = 12;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
} else if (j == 12) {
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 12,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 13) {
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 13,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 14) {
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 14,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 15) {
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 15,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 16) {
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 16,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
}
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -653,7 +653,7 @@ int main(void)
|
|||
|
||||
int k = 0;
|
||||
|
||||
// for (auto c : configs_sdxl_1024){
|
||||
// for (auto c : configs_sdxl_768){
|
||||
for (auto c : configs){
|
||||
test_model model;
|
||||
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
|
||||
|
|
|
|||
Loading…
Reference in New Issue