From 6f44f471133564003649fafc84b7264b28d39317 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Wed, 5 Nov 2025 13:04:37 -0500
Subject: [PATCH] add split-k mode for skinny MNK shapes

Implicit-GEMM conv2d maps M to n*Oh*Ow, N to the filter count k, and K
to c*r*s. When K is much larger than both M and N, a single pass keeps
too few blocks in flight, so split the K loop across blockIdx.z into
ksplit slices, write per-slice partial outputs, and sum them with the
(previously commented-out, now templated) reduce_f32 kernel.

---
 ggml/src/ggml-cuda/conv2d-implicit.cu  | 108 +++++++++++++++++--------
 ggml/src/ggml-cuda/conv2d-implicit.cuh |  41 ++++++----
 tests/test-conv2d.cpp                  |  12 ++-
 3 files changed, 103 insertions(+), 58 deletions(-)

diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 6bc93b2a57..216f895922 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
 
-//currently not use; in future for split-k kernels
-// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
-//     const int row = blockIdx.x;
-//     const int col = threadIdx.x;
+// column-wise reduction of the per-slice partial sums produced by the split-k path
+template <typename src_T, typename dst_T>
+static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
 
-//     float sum = 0.0f;
-//     if (row * blockDim.x + col < ncols) {
-//         for (int i = 0; i < nrows; ++i){
-//             sum += x[i * ncols + row * blockDim.x + col];
-//         }
-//         dst[row * blockDim.x + col] = sum;
-//     }
-// }
+    float sum = 0.0f;
+    if (row * blockDim.x + col < ncols) {
+        for (int i = 0; i < nrows; ++i){
+            sum += ggml_cuda_cast<float>(x[i * ncols + row * blockDim.x + col]);
+        }
+        dst[row * blockDim.x + col] = ggml_cuda_cast<dst_T>(sum);
+    }
+}
 
 template <typename src_T, typename dst_T>
 static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@@ -705,26 +706,32 @@ __device__ __forceinline__ void ldmatrix_b(
 }
 
 template <const int BM, const int BN, const int BK, const int WM, const int WN,
-          const int WK, const int NUM_THREADS>
+          const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                                               const half * __restrict__ kernel,
                                               half * __restrict__ output,
                                               const param_t param) {
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-constexpr unsigned int MMA_M = 16;
-constexpr unsigned int MMA_N = 8;
-
+    constexpr unsigned int MMA_M = 16;
+    constexpr unsigned int MMA_N = 8;
     const unsigned int K = param.c * param.r * param.s;
     const uint inChannelOffset = param.c * param.w;
-    const uint weightKOffset = param.c * param.r * param.s;
+    const uint weightKOffset = K;
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
    constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
    constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
+    const unsigned int z = blockIdx.z;
+
+    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
+    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
+
+
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
@@ -770,8 +777,8 @@ constexpr unsigned int MMA_N = 8;
 
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param);
-    tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param);
+    tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param);
+    tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
@@ -781,8 +788,8 @@ constexpr unsigned int MMA_N = 8;
         if (block_k != num_block_tiles_k){
             const half* A_block_gmem = input;
             const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
-            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
-            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
         half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -813,6 +820,8 @@ constexpr unsigned int MMA_N = 8;
             }
         }
 
+
+
         if (block_k != num_block_tiles_k) {
             // switch smem buffers each iteration
@@ -863,11 +872,18 @@ constexpr unsigned int MMA_N = 8;
                 const uint gemm_i = n_idx + j*32;
                 const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
                 const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
-                if(n < param.n && row < param.k && col < param.Oh * param.Ow){
-                    const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                if (n < param.n && row < param.k && col < param.Oh * param.Ow) {
                     uint idx = output_lds_addr + subk + j*32*BN/2;
                     idx = idx ^ ((idx & 0b1110000000) >> 4);
-                    output[outOffset] = smemoutput[idx];
+                    if constexpr (ksplit > 0) {
+                        const uint outOffset = z * param.n * param.k * param.Oh * param.Ow +
+                                               n * param.k * param.Oh * param.Ow +
+                                               row * param.Oh * param.Ow + col;
+                        output[outOffset] = smemoutput[idx];
+                    } else {
+                        const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                        output[outOffset] = smemoutput[idx];
+                    }
                 }
             }
         }
@@ -952,7 +968,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 
         const half *X_H = input_f16.get();
         const half *K_H = kernel_f16.get();
-        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
 
         constexpr unsigned int BM_dim = 256;
         constexpr unsigned int BN_dim = 256;
@@ -972,16 +987,41 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
 
         const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
-        cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
-                cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
-        dim3 gridDim(BlocksN, BlocksM);
-        dim3 blockDim(ThreadsN, ThreadsM);
+        const unsigned int K2MN = 8;
 
-        conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>
-                <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+        if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) {
+            const unsigned int ksplit = 8;
+            ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
+
+            cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,
+                    cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB, the maximum for sm_75
+            dim3 gridDim(BlocksN, BlocksM, ksplit);
+            dim3 blockDim(ThreadsN, ThreadsM);
+
+            conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>
+                    <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+
+            const unsigned int ncols = P.n * P.k * P.Oh * P.Ow;
+            const unsigned int blockx = (ncols + 511) / 512;
+            const dim3 block_nums(blockx, 1, 1);
+            const dim3 block_dims(512, 1, 1);
+            reduce_f32<<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, ncols, ksplit);
+
+        } else {
+            ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
+
+            cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
+                    cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB, the maximum for sm_75
+            dim3 gridDim(BlocksN, BlocksM);
+            dim3 blockDim(ThreadsN, ThreadsM);
+
+            conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>
+                    <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+        }
     } else{
         conv2d_implicit_cuda(X_D, K_D, Y_D, P, st);
     }
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 347ca12b3e..b242277eb0 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -32,6 +32,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -57,9 +59,9 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // filter row offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
 
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
@@ -68,7 +70,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -91,7 +93,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
-    // const unsigned int src_stride,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inChannelOffset,
     param_t param
 )
@@ -128,9 +131,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
     int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
     unsigned int inOffset = n * param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // filter row offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
     int curH = posh_ori + curR * param.d_h; // input h
     int curW = posw_ori + curS * param.d_w; // input w
     // apply swizzle to the dst index
@@ -138,7 +141,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.r && curS < param.s && curC < param.c){
+            curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
         } else{
@@ -164,6 +167,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inChannelOffset,
     param_t param
 ){
@@ -194,13 +199,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
     int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
     unsigned int inOffset = n * param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // filter row offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
     int curH = posh_ori + curR * param.d_h; // input h
     int curW = posw_ori + curS * param.d_w; // input w
     if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-        curR < param.r && curS < param.s && curC < param.c){
+        curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
         const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
         dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
     } else{
@@ -227,6 +232,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -249,14 +256,14 @@ __device__ __forceinline__ void tileMemcpyLoadB(
 
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
-    const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // filter row offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // filter column offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
 
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
+        const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 48dcaa47d8..b5e7b18a2a 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -309,7 +309,6 @@ int main(void)
         std::make_tuple(4,320,64,96,3,3),
         std::make_tuple(320,4,64,96,3,3),
         std::make_tuple(640,640,96,128,3,3),
-        std::make_tuple(320,1280,26,38,3,3),
         std::make_tuple(1280,1280,26,38,1,1),
         std::make_tuple(256,128,768,1024,3,3),
         std::make_tuple(128,3,768,1024,3,3),
@@ -385,14 +384,13 @@ int main(void)
 
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
-    //     for(int i = 0; i < conv2d_data.size(); i++) {
-    //     //     float diff = fabs(conv2d_data[i] - wino_data[i]);
-    //     float diff = fabs(im2col_data[i] - wino_data[i]);
-    //     float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
+    // // for(int i = 26*38; i < 2*26*38; i++) {
+    // // for(int i = 0; i < conv2d_data.size(); i++) {
+    //     float diff = fabs(im2col_data[i] - conv2d_data[i]);
     //     // if(diff > 0.5) {
-    //     printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
+    //     printf("(%7.3f, %7.3f, %.2f, %d) \n",
     //         im2col_data[i], conv2d_data[i],
-    //         wino_data[i], diff, diff1, i);
+    //         diff, i);
     //     // break;
     //     // }
     // }
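
Reviewer note (not part of the patch): the split-k data flow in isolation.
Each blockIdx.z slice covers ks = (K + ksplit - 1) / ksplit columns of the
K = c*r*s loop, i.e. the half-open range [z*ks, min((z+1)*ks, K)); the new
start_k/end_k guards in the tile-copy helpers zero-fill reads past end_k,
so a short final slice simply contributes zeros to the sum. The standalone
sketch below mirrors the reduce_f32 launch shape above (512 threads,
ceil-div grid) but uses plain float buffers in place of the half partials /
float destination pair; the kernel and variable names here are illustrative
only, not taken from the patch.

// sketch: ksplit slices each hold a partial output of ncols elements;
// a column-wise sum collapses them into the final output
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

static __global__ void reduce_partials(const float * __restrict__ x, float * __restrict__ dst,
                                       const int ncols, const int nrows) {
    // same element as the patch's row * blockDim.x + col indexing
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col < ncols) {
        float sum = 0.0f;
        for (int i = 0; i < nrows; ++i) { // walk the ksplit partial buffers
            sum += x[i * ncols + col];
        }
        dst[col] = sum;
    }
}

int main() {
    const int ksplit = 8;    // number of K slices == gridDim.z of the conv kernel
    const int ncols  = 4096; // stands in for n * k * Oh * Ow

    std::vector<float> h(ksplit * ncols);
    for (int i = 0; i < ksplit * ncols; ++i) {
        h[i] = (float) (i / ncols); // slice z holds the value z everywhere
    }

    float * partials = nullptr;
    float * out      = nullptr;
    cudaMalloc(&partials, ksplit * ncols * sizeof(float));
    cudaMalloc(&out, ncols * sizeof(float));
    cudaMemcpy(partials, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);

    // same launch shape as the patch: 512 threads, ceil-div grid over ncols
    const dim3 block_dims(512, 1, 1);
    const dim3 block_nums((ncols + 511) / 512, 1, 1);
    reduce_partials<<<block_nums, block_dims>>>(partials, out, ncols, ksplit);

    float first = 0.0f;
    cudaMemcpy(&first, out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %.1f (expect %d)\n", first, ksplit * (ksplit - 1) / 2); // 0+1+...+7 = 28

    cudaFree(partials);
    cudaFree(out);
    return 0;
}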