fixed a bug now split-k can choose a better split factor

This commit is contained in:
bssrdf 2025-11-08 16:45:59 -05:00
parent 9cbc099493
commit a1fb3c1509
3 changed files with 24 additions and 10 deletions

View File

@ -856,13 +856,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
// const unsigned int ksplit = 6;
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN));
if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){
if (BlocksM * BlocksN < nsm){
int ks = nsm / (BlocksM * BlocksN);
printf("split factor init = %d \n", ks);
int j;
bool can_split = false;
for (j = ks; j >= 2; j--){
@ -872,7 +869,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
}
}
if(can_split){
printf("split factor = %d \n", j);
if (j == 2) {
const unsigned int ksplit = 2;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
@ -901,6 +897,22 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const unsigned int ksplit = 8;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 9) {
const unsigned int ksplit = 9;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 10) {
const unsigned int ksplit = 10;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 11) {
const unsigned int ksplit = 11;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else {
const unsigned int ksplit = 12;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
}
return;
}

View File

@ -72,7 +72,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -273,7 +273,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + ki;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);

View File

@ -300,7 +300,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(640,640,104,152,3,3),
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(1920,640,32,32,3,3)
// std::make_tuple(1920,640,32,32,3,3)
std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(320,640,32,32,3,3),
// std::make_tuple(4,320,96,128,3,3),
// std::make_tuple(320,4,96,128,3,3),
// std::make_tuple(4,320,64,96,3,3),
@ -651,8 +653,8 @@ int main(void)
int k = 0;
// for (auto c : configs_sdxl_512){
for (auto c : configs){
for (auto c : configs_sdxl_768){
// for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), true);