fixed a bug now split-k can choose a better split factor
This commit is contained in:
parent
9cbc099493
commit
a1fb3c1509
|
|
@ -856,13 +856,10 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
// const unsigned int ksplit = 6;
|
||||
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
|
||||
printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN));
|
||||
if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){
|
||||
if (BlocksM * BlocksN < nsm){
|
||||
|
||||
int ks = nsm / (BlocksM * BlocksN);
|
||||
printf("split factor init = %d \n", ks);
|
||||
int j;
|
||||
bool can_split = false;
|
||||
for (j = ks; j >= 2; j--){
|
||||
|
|
@ -872,7 +869,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
}
|
||||
}
|
||||
if(can_split){
|
||||
printf("split factor = %d \n", j);
|
||||
if (j == 2) {
|
||||
const unsigned int ksplit = 2;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
|
|
@ -901,6 +897,22 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
|
|||
const unsigned int ksplit = 8;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 9) {
|
||||
const unsigned int ksplit = 9;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 10) {
|
||||
const unsigned int ksplit = 10;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else if (j == 11) {
|
||||
const unsigned int ksplit = 11;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
} else {
|
||||
const unsigned int ksplit = 12;
|
||||
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
|
||||
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
|||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
@ -273,7 +273,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
|||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
const unsigned int src_index = thread_row * src_stride + ki;
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
|
|||
|
|
@ -300,7 +300,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
|
|||
// std::make_tuple(640,640,104,152,3,3),
|
||||
// std::make_tuple(960,320,104,152,3,3),
|
||||
// std::make_tuple(1280,1280,26,38,3,3),
|
||||
std::make_tuple(1920,640,32,32,3,3)
|
||||
// std::make_tuple(1920,640,32,32,3,3)
|
||||
std::make_tuple(1280,1280,16,16,3,3),
|
||||
// std::make_tuple(320,640,32,32,3,3),
|
||||
// std::make_tuple(4,320,96,128,3,3),
|
||||
// std::make_tuple(320,4,96,128,3,3),
|
||||
// std::make_tuple(4,320,64,96,3,3),
|
||||
|
|
@ -651,8 +653,8 @@ int main(void)
|
|||
|
||||
int k = 0;
|
||||
|
||||
// for (auto c : configs_sdxl_512){
|
||||
for (auto c : configs){
|
||||
for (auto c : configs_sdxl_768){
|
||||
// for (auto c : configs){
|
||||
test_model model;
|
||||
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
|
||||
std::get<3>(c), std::get<4>(c), std::get<5>(c), true);
|
||||
|
|
|
|||
Loading…
Reference in New Issue