broken for some test cases

bssrdf 2025-11-08 14:51:45 -05:00
parent 64ead3fd4f
commit 9cbc099493
2 changed files with 127 additions and 69 deletions

View File

@@ -779,6 +779,33 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
template<const int BM, const int BN, const int BK,
const int WM, const int WN, const int WK, const int ksplit,
const unsigned int ThreadsM, const unsigned int ThreadsN,
const int NUM_THREADS>
static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
const unsigned int BlocksM, const unsigned int BlocksN,
const unsigned int shmem_bytes,
const param_t P, cudaStream_t st){
int id = ggml_cuda_get_device();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
cudaFuncSetAttribute(conv2d_implicit_kernel<BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM, ksplit);
dim3 blockDim(ThreadsN, ThreadsM);
conv2d_implicit_kernel<BM, BN, BK,
WM, WN, WK, ksplit, NUM_THREADS><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
const unsigned int blockx = (nrows + 511) / 512;
const dim3 block_nums(blockx, 1, 1);
const dim3 block_dims(512, 1, 1);
reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
}
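The helper above stages the split partial results in a half-precision scratch buffer and then sums them across the ksplit grid slices with reduce_f32. As a rough standalone illustration of that reduction step (not the actual reduce_f32 kernel from the ggml CUDA sources, and assuming each split slice occupies a contiguous block of nrows elements):

#include <cuda_fp16.h>

// Illustrative only: sum ksplit partial fp16 outputs into one fp32 result.
// Launched with enough threads so that every output row gets one thread.
__global__ void reduce_partials_sketch(const half * __restrict__ y_partial,
                                       float * __restrict__ y_out,
                                       const unsigned int nrows,
                                       const unsigned int ksplit) {
    const unsigned int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= nrows) {
        return;
    }
    float acc = 0.0f;
    for (unsigned int i = 0; i < ksplit; ++i) {
        acc += __half2float(y_partial[i * nrows + row]);  // stride of one full slice
    }
    y_out[row] = acc;
}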
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
@@ -829,39 +856,67 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const unsigned int ksplit = 8;
if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
// const unsigned int ksplit = 6;
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
printf("split factor info = %d, %d, %d \n", BlocksM, BlocksN, nsm / (BlocksM * BlocksN));
if (BlocksM * BlocksN < nsm && nsm / (BlocksM * BlocksN) <= 8 ){
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM, ksplit);
dim3 blockDim(ThreadsN, ThreadsM);
conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, ksplit, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
const unsigned int blockx = (nrows + 511) / 512;
const dim3 block_nums(blockx, 1, 1);
const dim3 block_dims(512, 1, 1);
reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
} else {
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, 0, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
int ks = nsm / (BlocksM * BlocksN);
printf("split factor init = %d \n", ks);
int j;
bool can_split = false;
for (j = ks; j >= 2; j--){
if ((P.c * P.r * P.s) % (8*j) == 0){
can_split = true;
break;
}
}
if(can_split){
printf("split factor = %d \n", j);
if (j == 2) {
const unsigned int ksplit = 2;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 3) {
const unsigned int ksplit = 3;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 4) {
const unsigned int ksplit = 4;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 5) {
const unsigned int ksplit = 5;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 6) {
const unsigned int ksplit = 6;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 7) {
const unsigned int ksplit = 7;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 8) {
const unsigned int ksplit = 8;
launch_conv2d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
}
return;
}
}
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, 0, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
} else{
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
}
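Stripped of the template dispatch, the split-factor selection added in this hunk reduces to the sketch below (pick_split_factor is a hypothetical helper written here for clarity, not code from the commit): split only when the launch grid underfills the SMs, cap the factor at 8, and require the reduction length c*r*s to stay a multiple of 8 per slice.

// Hypothetical helper mirroring the selection logic above.
static int pick_split_factor(int nsm, int blocks_m, int blocks_n,
                             int c, int r, int s) {
    if (blocks_m * blocks_n >= nsm) {
        return 1;                          // the grid already fills the SMs: no split
    }
    const int ks = nsm / (blocks_m * blocks_n);
    if (ks > 8) {
        return 1;                          // the commit only dispatches ksplit = 2..8
    }
    for (int j = ks; j >= 2; --j) {
        if ((c * r * s) % (8 * j) == 0) {  // each slice must keep a multiple-of-8 depth
            return j;                      // largest admissible split factor
        }
    }
    return 1;                              // no admissible factor: keep the single-pass path
}

The if/else ladder over j = 2..8 is needed because ksplit is a compile-time template parameter of conv2d_implicit_kernel, so each runtime value of the split factor has to be mapped onto its own instantiation of launch_conv2d_implicit_split_kernel.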

View File

@@ -293,42 +293,38 @@ std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr
}
int main(void)
{
ggml_time_init();
double time_iter0 = 0.0, time_iter1 = 0.0;
std::vector<std::tuple<int, int, int, int, int, int>> configs = {
std::make_tuple(64,64,48,64,3,3),
std::make_tuple(320,320,104,152,3,3),
std::make_tuple(640,640,52,76,3,3),
std::make_tuple(640,640,104,152,3,3),
std::make_tuple(960,320,104,152,3,3),
std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(4,320,96,128,3,3),
std::make_tuple(320,4,96,128,3,3),
std::make_tuple(4,320,64,96,3,3),
std::make_tuple(320,4,64,96,3,3),
std::make_tuple(640,640,96,128,3,3),
std::make_tuple(1280,1280,26,38,1,1),
std::make_tuple(256,128,768,1024,3,3),
std::make_tuple(128,3,768,1024,3,3),
std::make_tuple(256,128,768,1024,1,1),
std::make_tuple(512,256,384,512,1,1),
std::make_tuple(1280,640,52,76,3,3),
std::make_tuple(1920,1280,26,38,3,3),
std::make_tuple(2560,1280,26,38,3,3),
std::make_tuple(320,1280,26,38,3,3),
std::make_tuple(512,512,104,152,3,3),
std::make_tuple(512,512,208,304,3,3),
std::make_tuple(512,256,416,608,3,3),
std::make_tuple(256,128,832,1216,3,3),
std::make_tuple(256,256,832,1216,3,3),
std::make_tuple(32,64,58,58,3,3)
static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(64,64,48,64,3,3),
// std::make_tuple(320,320,104,152,3,3),
// std::make_tuple(640,640,52,76,3,3),
// std::make_tuple(640,640,104,152,3,3),
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(1920,640,32,32,3,3)
// std::make_tuple(4,320,96,128,3,3),
// std::make_tuple(320,4,96,128,3,3),
// std::make_tuple(4,320,64,96,3,3),
// std::make_tuple(320,4,64,96,3,3),
// std::make_tuple(640,640,96,128,3,3),
// std::make_tuple(1280,1280,26,38,1,1),
// std::make_tuple(256,128,768,1024,3,3),
// std::make_tuple(128,3,768,1024,3,3),
// std::make_tuple(256,128,768,1024,1,1),
// std::make_tuple(512,256,384,512,1,1),
// std::make_tuple(1280,640,52,76,3,3),
// std::make_tuple(1920,1280,26,38,3,3),
// std::make_tuple(2560,1280,26,38,3,3),
// std::make_tuple(320,1280,26,38,3,3),
// std::make_tuple(512,512,104,152,3,3),
// std::make_tuple(512,512,208,304,3,3),
// std::make_tuple(512,256,416,608,3,3),
// std::make_tuple(256,128,832,1216,3,3),
// std::make_tuple(256,256,832,1216,3,3),
// std::make_tuple(32,64,58,58,3,3)
// std::make_tuple(320,256,1024,1920)
};
std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_512 = {
static std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_512 = {
//512x512
std::make_tuple(4,320,64,64,3,3),
std::make_tuple(320,320,64,64,3,3),
@@ -434,7 +430,7 @@ int main(void)
std::make_tuple(320,4,64,64,3,3)
};
std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_768 = {
static std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_768 = {
//768x768
std::make_tuple(4,320,96,96,3,3),
std::make_tuple(320,320,96,96,3,3),
@@ -540,7 +536,7 @@ int main(void)
std::make_tuple(320,4,96,96,3,3),
};
std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_1024 = {
static std::vector<std::tuple<int, int, int, int, int, int>> configs_sdxl_1024 = {
//1024x1024
std::make_tuple(4,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
@@ -646,10 +642,17 @@ int main(void)
std::make_tuple(320,4,128,128,3,3)
};
int main(void)
{
ggml_time_init();
double time_iter0 = 0.0, time_iter1 = 0.0;
int k = 0;
for (auto c : configs_sdxl_512){
// for (auto c : configs){
// for (auto c : configs_sdxl_512){
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), true);