diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cu b/ggml/src/ggml-cuda/conv3d-implicit.cu
index 89e9ebf2a6..fd6e1d0b71 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cu
@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
 
 //currently not use; in future for split-k kernels
-// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
-//     const int row = blockIdx.x;
-//     const int col = threadIdx.x;
+template <typename src_T, typename dst_T>
+static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
 
-//     float sum = 0.0f;
-//     if (row * blockDim.x + col < ncols) {
-//         for (int i = 0; i < nrows; ++i){
-//             sum += x[i * ncols + row * blockDim.x + col];
-//         }
-//         dst[row * blockDim.x + col] = sum;
-//     }
-// }
+    float sum = 0.0f;
+    if (row * blockDim.x + col < ncols) {
+        for (int i = 0; i < nrows; ++i){
+            sum += ggml_cuda_cast<float>(x[i * ncols + row * blockDim.x + col]);
+        }
+        dst[row * blockDim.x + col] = ggml_cuda_cast<dst_T>(sum);
+    }
+}
 
 template <typename src_T, typename dst_T>
 static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@@ -682,7 +683,7 @@ __device__ __forceinline__ void ldmatrix_b(
 }
 
 template <typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
-          const int WK, const int NUM_THREADS>
+          const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                                               const half * __restrict__ kernel,
                                               T * __restrict__ output,
@@ -699,12 +700,19 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
     const uint inChannelOffset = param.c * param.w;
     const uint inDepthOffset = param.h * param.c * param.w;
     const uint inNOffset = param.c * param.w * param.h * param.d;
+    const unsigned int z = blockIdx.z;
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
+
+    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
+    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
+
+// const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
@@ -750,8 +758,8 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
 
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
-    tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param);
+    tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inNOffset, inDepthOffset, inChannelOffset, param);
+    tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
 
@@ -761,9 +769,10 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
         if (block_k != num_block_tiles_k){
             const half* A_block_gmem = input;
             const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
-            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK,
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k,
                             inNOffset, inDepthOffset, inChannelOffset, param);
-            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k,
+                            weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
         half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -852,12 +861,28 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
                 uint32_t dst_ptr = *(reinterpret_cast<uint32_t *>(&smemoutput[idx+j*32*BN/2]));
                 half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
                 if(n < param.n && row < param.k && col < PQZ){
-                    const uint outOffset = (n * param.k + row) * PQZ + col;
+                    // if constexpr (ksplit > 0) {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                    // } else {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                    // }
+                    const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row) * PQZ + col :
+                                                        (n * param.k + row) * PQZ + col;
                     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
                 }
                 if(n < param.n && row+1 < param.k && col < PQZ){
-                    const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                    const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row+1) * PQZ + col :
+                                                        (n * param.k + row+1) * PQZ + col;
                     output[outOffset] = ggml_cuda_cast<T>(res_[1]);
+                    // if constexpr (ksplit > 0) {
+                    //     const uint outOffset = (n * param.k + row) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast<T>(res_[0]);
+                    // } else {
+                    //     const uint outOffset = (n * param.k + row + 1) * PQZ + col;
+                    //     output[outOffset] = ggml_cuda_cast<T>(res_[1]);
+                    // }
                 }
             }
         }
@@ -914,6 +939,33 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
                            WNITER, TM, TN, NUM_THREADS, 1, false, 0><<>>(X_D, K_D, Y_D, P);
 }
 
+template <const int BM, const int BN, const int BK, const int WM, const int WN, const int WK,
+          const int ksplit, const int NumThreads, const int ThreadsM, const int ThreadsN>
+static void launch_conv3d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
+                                                const unsigned int BlocksM, const unsigned int BlocksN,
+                                                const unsigned int shmem_bytes,
+                                                const param_t P, cudaStream_t st){
+
+    int id = ggml_cuda_get_device();
+
+    ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Od * P.Oh * P.Ow * P.n);
+    cudaFuncSetAttribute(conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NumThreads>,
+                         cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
+    dim3 gridDim(BlocksN, BlocksM, ksplit);
+    dim3 blockDim(ThreadsN, ThreadsM);
+
+    conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NumThreads>
+        <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+
+    const unsigned int nrows = P.n * P.k * P.Oh * P.Ow * P.Od;
+    const unsigned int blockx = (nrows + 511) / 512;
+    const dim3 block_nums(blockx, 1, 1);
+    const dim3 block_dims(512, 1, 1);
+    reduce_f32<<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
+}
+
 static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D,
                                      int cc, const param_t P, cudaStream_t st) {
     if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
@@ -956,6 +1008,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
         constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
         constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
+
+        static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free");
+
         const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
         const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
         constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
@@ -963,13 +1018,73 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
         const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
 
-        cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
+        const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
+        if (BlocksM * BlocksN < 2*(unsigned int)nsm){
+            int j, max_remaining_waves = -1, candidate = -1;
+            int ks = min(12, nsm / (BlocksM * BlocksN));
+            if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
+                ks = 12;
+            for (j = 2; j <= ks; j++){
+                const int remainder = (BlocksM * BlocksN * j) % nsm;
+                if ((P.c * P.r * P.s * P.t) % (8*j) == 0){
+                    if (remainder == 0) {
+                        candidate = j;
+                        max_remaining_waves = 0;
+                        break;
+                    } else if (remainder > max_remaining_waves) {
+                        max_remaining_waves = remainder;
+                        candidate = j;
+                    }
+                }
+            }
+
+            if(candidate != -1){
+                j = candidate;
+                if (j == 2) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 2, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 3) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 3, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 4) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 4, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 5) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 5, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 6) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 6, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 7) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 7, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 8) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 8, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 9) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 9, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 10) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 10, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 11) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 11, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                } else if (j == 12) {
+                    launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim,
+                        WK_dim, 12, NumThreads, ThreadsM, ThreadsN>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
+                }
+                return;
+            }
+        }
+        cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
                              cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
         dim3 gridDim(BlocksN, BlocksM);
         dim3 blockDim(ThreadsN, ThreadsM);
 
         conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
-                               WM_dim, WN_dim, WK_dim, NumThreads>
+                               WM_dim, WN_dim, WK_dim, 0, NumThreads>
             <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
     } else{
         conv3d_implicit_cuda(X_D, K_D, Y_D, P, st);
diff --git a/ggml/src/ggml-cuda/conv3d-implicit.cuh b/ggml/src/ggml-cuda/conv3d-implicit.cuh
index 37449f677e..d7d8ef1086 100644
--- a/ggml/src/ggml-cuda/conv3d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv3d-implicit.cuh
@@ -65,6 +65,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -90,7 +92,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -104,7 +106,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         // TODO: move some checks outside of loop?
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c && kidx < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -127,7 +129,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
-    // const unsigned int src_stride,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -157,7 +160,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
 
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int kidx = thread_col*8;
+    const unsigned int kidx = start_k+thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
     #pragma unroll
@@ -177,7 +180,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
             unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
             dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
             dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-            if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+            if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
                 int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
                 dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
             } else{
@@ -203,6 +206,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inNOffset,
     const unsigned int inDepthOffset,
     const unsigned int inChannelOffset,
@@ -227,7 +232,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
 
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
 
     #pragma unroll
@@ -243,7 +248,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         const int curH = posh_ori + curIdx.z * param.dilation1; // input h
         const int curW = posw_ori + curIdx.w * param.dilation0; // input w
         const int curC = curIdx.x;
-        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
+        if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
+            && curC < param.c && kidx < end_k){
             int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
         } else{
@@ -270,6 +276,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -292,7 +300,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
 
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
-    const unsigned int kidx = block_k + thread_col*8;
+    const unsigned int kidx = start_k + block_k + thread_col*8;
     const int4 curIdx = inputIndices<0>(kidx, param);
     const int curC = curIdx.x;
     const int curT = curIdx.y;
@@ -300,9 +308,10 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const int curS = curIdx.w;
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
+        const unsigned int src_index = thread_row * src_stride + kidx;
         // TODO : move some checks outside of the loop
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t
+            && curC < param.c && kidx < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv3d.cpp b/tests/test-conv3d.cpp
index 6483841a42..e05bc50090 100644
--- a/tests/test-conv3d.cpp
+++ b/tests/test-conv3d.cpp
@@ -350,9 +350,9 @@ int main(void) {
     ggml_time_init();
 
     std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
-        // std::make_tuple(1,2,16,32,4,3,3,3),
+        std::make_tuple(1,2,16,32,4,3,3,3),
         // std::make_tuple(320,1280,26,38,8,3,3,3),
-        std::make_tuple(1280,1280,26,38,8,3,3,3),
+        // std::make_tuple(1280,1280,26,38,8,3,3,3),
         // std::make_tuple(320,1280,52,76,8,3,3,3),
         // std::make_tuple(1280,1280,52,76,8,3,3,3),
         // std::make_tuple(320,1280,104,152,8,3,3,3),
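
Note on the approach: with ksplit > 0 the kernel grid gains a z dimension, each z-slice reduces only its K-range [start_k, end_k) of the flattened C*T*R*S axis into a temporary buffer of ksplit partial outputs, and reduce_f32 then sums the slices into Y_D. The sketch below is a minimal, self-contained toy of that two-pass split-k pattern, not the ggml kernels themselves; the names partial_kernel, reduce_kernel and NSPLIT, and the per-row dot-product workload, are illustrative assumptions only.

// Toy illustration of the split-k pattern used in the patch (not the ggml kernels).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

constexpr int NSPLIT = 4;  // number of slices of the reduction (K) dimension

// Each blockIdx.z handles one K-slice [start_k, end_k) and writes its partial
// result to partial[z * n + i], mirroring how the patched kernel offsets the
// temporary buffer with (z * param.n * param.k + ...).
__global__ void partial_kernel(const float * x, const float * w, float * partial,
                               const int n, const int K) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const int z = blockIdx.z;
    const int ks      = (K + NSPLIT - 1) / NSPLIT;
    const int start_k = z * ks;
    const int end_k   = min(start_k + ks, K);
    if (i >= n) {
        return;
    }
    float sum = 0.0f;
    for (int k = start_k; k < end_k; ++k) {
        sum += x[i * K + k] * w[k];
    }
    partial[z * n + i] = sum;
}

// Second pass: sum the NSPLIT partial rows, the same loop shape as reduce_f32 above.
__global__ void reduce_kernel(const float * partial, float * dst, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float sum = 0.0f;
    for (int z = 0; z < NSPLIT; ++z) {
        sum += partial[z * n + i];
    }
    dst[i] = sum;
}

int main() {
    const int n = 1024, K = 4096;
    std::vector<float> hx(n * K, 1.0f), hw(K, 0.5f), hy(n);

    float *dx, *dw, *dpartial, *dy;
    cudaMalloc(&dx, n * K * sizeof(float));
    cudaMalloc(&dw, K * sizeof(float));
    cudaMalloc(&dpartial, NSPLIT * n * sizeof(float));
    cudaMalloc(&dy, n * sizeof(float));
    cudaMemcpy(dx, hx.data(), n * K * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dw, hw.data(), K * sizeof(float), cudaMemcpyHostToDevice);

    const dim3 block(256);
    const dim3 grid_partial((n + 255) / 256, 1, NSPLIT); // z dimension = split index
    partial_kernel<<<grid_partial, block>>>(dx, dw, dpartial, n, K);
    reduce_kernel<<<(n + 255) / 256, block>>>(dpartial, dy, n);

    cudaMemcpy(hy.data(), dy, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("y[0] = %f (expected %f)\n", hy[0], 0.5f * K);

    cudaFree(dx); cudaFree(dw); cudaFree(dpartial); cudaFree(dy);
    return 0;
}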
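
The host-side choice of the split factor (the j loop added to conv3d_implicit_cuda_f16) can be desk-checked in isolation. Below is a stand-alone restatement under the assumptions that blocks stands for BlocksM*BlocksN, nsm for the SM count from ggml_cuda_info(), crst for P.c*P.r*P.s*P.t, and a return value of 0 for "take the non-split path"; pick_split_factor is a hypothetical helper for review, not a function in the patch.

// Host-side restatement of the split-factor heuristic, for desk-checking only.
#include <algorithm>
#include <cstdio>

static int pick_split_factor(int blocks, int nsm, long long crst) {
    if (blocks >= 2 * nsm) {
        return 0;                       // grid already saturates the device: no split
    }
    int ks = std::min(12, nsm / blocks);
    if (ks < 2 && blocks % nsm < nsm * 4 / 5) {
        ks = 12;
    }
    int candidate = -1, max_remaining_waves = -1;
    for (int j = 2; j <= ks; ++j) {
        if (crst % (8 * j) != 0) {
            continue;                   // each K slice must stay a multiple of 8 halves
        }
        const int remainder = (blocks * j) % nsm;
        if (remainder == 0) {
            return j;                   // grid becomes an exact multiple of the SM count
        }
        if (remainder > max_remaining_waves) {
            max_remaining_waves = remainder;
            candidate = j;
        }
    }
    return candidate > 0 ? candidate : 0;
}

int main() {
    // example: 40 output tiles on a 68-SM GPU, C*R*S*T = 1280*3*3*3
    printf("split factor: %d\n", pick_split_factor(40, 68, 1280LL * 27));
    return 0;
}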