added split-k mode to tensor core path

bssrdf 2025-11-10 14:38:23 -05:00
parent a428feecdd
commit 15daa5a6a8
3 changed files with 157 additions and 33 deletions
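In broad strokes, the split-k mode partitions the GEMM reduction (the per-filter reduction extent weightKOffset) across blockIdx.z: each z-slice accumulates a partial output into its own slab of a temporary half buffer, and a small reduce_f32 kernel then sums the slabs into the final float output. Below is a minimal, self-contained CUDA sketch of that two-stage pattern, with hypothetical names and a plain dot product standing in for the conv GEMM; it is illustrative only, not the commit's kernels.

// split_k_sketch.cu -- illustrative only, not part of this commit
#include <cuda_fp16.h>

// Stage 1: each grid z-slice reduces its own chunk of the K dimension into
// partial[blockIdx.z]; partial must be zero-initialized before launch.
static __global__ void partial_dot(const half * a, const half * b, float * partial,
                                   const int K, const int ksplit) {
    const int chunk = (K + ksplit - 1) / ksplit;   // ceil-divide K across z-slices
    const int start = blockIdx.z * chunk;
    const int end   = min(start + chunk, K);
    float sum = 0.0f;
    for (int k = start + threadIdx.x; k < end; k += blockDim.x) {
        sum += __half2float(a[k]) * __half2float(b[k]);
    }
    atomicAdd(&partial[blockIdx.z], sum);          // one partial value per slab
}

// Stage 2: fold the ksplit partial slabs into the final result, which is the
// role reduce_f32 plays for the conv output in this commit.
static __global__ void reduce_partials(const float * partial, float * out, const int ksplit) {
    float sum = 0.0f;
    for (int i = 0; i < ksplit; ++i) {
        sum += partial[i];
    }
    *out = sum;
}

A launch would look like partial_dot<<<dim3(1, 1, ksplit), 256>>>(a, b, partial, K, ksplit) followed by reduce_partials<<<1, 1>>>(partial, out, ksplit); the real kernels write whole partial output tiles per z-slice and reduce them element-wise.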

@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
// currently not used; kept for future split-k kernels
// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
// const int row = blockIdx.x;
// const int col = threadIdx.x;
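// Sums `nrows` partial-result slabs of `ncols` elements each into `dst`, casting
// through float; the split-k path below launches it with ncols = total output
// elements and nrows = ksplit.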
template <typename src_T, typename dst_T>
static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x;
const int col = threadIdx.x;
// float sum = 0.0f;
// if (row * blockDim.x + col < ncols) {
// for (int i = 0; i < nrows; ++i){
// sum += x[i * ncols + row * blockDim.x + col];
// }
// dst[row * blockDim.x + col] = sum;
// }
// }
float sum = 0.0f;
if (row * blockDim.x + col < ncols) {
for (int i = 0; i < nrows; ++i){
sum += ggml_cuda_cast<float>(x[i * ncols + row * blockDim.x + col]);
}
dst[row * blockDim.x + col] = ggml_cuda_cast<dst_T>(sum);
}
}
template <typename src_T, typename dst_T>
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
@@ -682,7 +683,7 @@ __device__ __forceinline__ void ldmatrix_b(
}
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
const int WK, const int ksplit, const int NUM_THREADS>
static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const half * __restrict__ kernel,
T * __restrict__ output,
@@ -699,12 +700,19 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const uint inChannelOffset = param.c * param.w;
const uint inDepthOffset = param.h * param.c * param.w;
const uint inNOffset = param.c * param.w * param.h * param.d;
const unsigned int z = blockIdx.z;
// loop bounds; keeping them constexpr where possible allows loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
const unsigned int end_k = min(start_k + ks, weightKOffset);
const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
// const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
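// split-k: grid z-slice `z` covers K indices [start_k, end_k) with chunk size
// ks = ceil(weightKOffset / ksplit), e.g. weightKOffset = 100, ksplit = 3 gives
// chunks [0,34), [34,68), [68,100); with ksplit == 0 the whole range is used.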
// calculate block/warp indices
const unsigned int block_m = blockIdx.y;
@@ -750,8 +758,8 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, start_k, end_k, inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
int offset_direction = 1;
@@ -761,9 +769,10 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
if (block_k != num_block_tiles_k){
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK,
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k,
inNOffset, inDepthOffset, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k,
weightKOffset, param);
}
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -852,12 +861,28 @@ static __global__ void conv3d_implicit_kernel(const half * __restrict__ input,
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
if(n < param.n && row < param.k && col < PQZ){
const uint outOffset = (n * param.k + row) * PQZ + col;
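// split-k: each grid z-slice writes its partial tile into its own slab of the
// temporary output (slab stride param.n * param.k * PQZ); reduce_f32 then sums
// the ksplit slabs.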
const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row) * PQZ + col :
(n * param.k + row) * PQZ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
}
if(n < param.n && row+1 < param.k && col < PQZ){
const uint outOffset = (n * param.k + row + 1) * PQZ + col;
const uint outOffset = ksplit > 0 ? (z * param.n * param.k + n * param.k + row+1) * PQZ + col :
(n * param.k + row+1) * PQZ + col;
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
}
}
}
@@ -914,6 +939,33 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
}
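// Split-k launcher: allocate a temporary half buffer holding ksplit copies of the
// output, launch the conv kernel with gridDim.z = ksplit so every z-slice produces
// a partial sum over its K-chunk, then fold the slabs into Y_D with reduce_f32.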
template<const int BM, const int BN, const int BK,
const int WM, const int WN, const int WK, const int ksplit,
const unsigned int ThreadsM, const unsigned int ThreadsN,
const int NUM_THREADS>
static void launch_conv3d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
const unsigned int BlocksM, const unsigned int BlocksN,
const unsigned int shmem_bytes,
const param_t P, cudaStream_t st){
int id = ggml_cuda_get_device();
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Od * P.Oh * P.Ow * P.n);
cudaFuncSetAttribute(conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB, which is the maximum for sm_75
dim3 gridDim(BlocksN, BlocksM, ksplit);
dim3 blockDim(ThreadsN, ThreadsM);
conv3d_implicit_kernel<half, BM, BN, BK, WM, WN, WK, ksplit, NUM_THREADS>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
const unsigned int nrows = P.n * P.k * P.Oh * P.Ow * P.Od;
const unsigned int blockx = (nrows + 511) / 512;
const dim3 block_nums(blockx, 1, 1);
const dim3 block_dims(512, 1, 1);
reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
}
static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
@@ -956,6 +1008,9 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
static_assert(WN_dim % 4 == 0, "WN_dim must be a multiple of 4 for a bank-conflict-free final output write");
const unsigned int BlocksM = (P.n * P.Oh * P.Ow * P.Od + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
@@ -963,13 +1018,73 @@ static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
// if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
if (BlocksM * BlocksN < 2*(unsigned int)nsm){
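// Heuristic (see the loop below): the unsplit grid launches fewer than two waves
// of blocks, so try split factors j up to 12 that keep c*t*r*s a multiple of 8*j
// (every K-chunk then stays a multiple of the 8-half vector width), preferring a j
// whose total block count is a multiple of nsm, otherwise the fullest last wave.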
int j, max_remaining_waves = -1, candidate = -1;
int ks = min(12, nsm / (BlocksM * BlocksN));
if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5)
ks = 12;
for (j = 2; j <= ks; j++){
const int remainder = (BlocksM * BlocksN * j) % nsm;
if ((P.c * P.r * P.s * P.t) % (8*j) == 0){
if (remainder == 0) {
candidate = j;
max_remaining_waves = 0;
break;
} else if (remainder > max_remaining_waves) {
max_remaining_waves = remainder;
candidate = j;
}
}
}
if(candidate != -1){
j = candidate;
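// ksplit is a compile-time template parameter, so the runtime candidate has to be
// mapped onto one of the pre-instantiated launchers below (2..12).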
if (j == 2) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 2,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 3) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 3,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 4) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 4,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 5) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 5,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 6) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 6,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 7) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 7,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 8) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 8,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 9) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 9,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 10) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 10,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 11) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 11,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
} else if (j == 12) {
launch_conv3d_implicit_split_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 12,
ThreadsM, ThreadsN, NumThreads>(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st);
}
return;
}
}
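// Fall through to the single-pass path: ksplit = 0 makes start_k/end_k span the
// whole reduction and the kernel writes directly into Y_D.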
cudaFuncSetAttribute(conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB, which is the maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
dim3 blockDim(ThreadsN, ThreadsM);
conv3d_implicit_kernel<float, BM_dim, BN_dim, BK_dim,
WM_dim, WN_dim, WK_dim, NumThreads>
WM_dim, WN_dim, WK_dim, 0, NumThreads>
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_D, P);
} else{
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);

@@ -65,6 +65,8 @@ unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
const half* src,
half* dst,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int src_stride,
param_t param
){
@@ -90,7 +92,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int kidx = thread_col*8;
const unsigned int kidx = start_k + thread_col*8;
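// split-k: this block only loads K indices in [start_k, end_k); the kidx < end_k
// check below zero-fills anything outside this z-slice's chunk.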
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
@@ -104,7 +106,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
// TODO: move some checks outside of loop?
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c && kidx < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -127,7 +129,8 @@ unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
const half* src,
half* dst,
// const unsigned int src_stride,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
@@ -157,7 +160,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int kidx = thread_col*8;
const unsigned int kidx = start_k+thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
@@ -177,7 +180,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
@@ -203,6 +206,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
@@ -227,7 +232,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int kidx = block_k + thread_col*8;
const unsigned int kidx = start_k + block_k + thread_col*8;
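// kidx combines the split-k start with this BK tile's offset; the end_k check in
// the guard below zero-fills loads past this z-slice's K-chunk.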
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
@@ -243,7 +248,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c){
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
&& curC < param.c && kidx < end_k){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
@@ -270,6 +276,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int src_stride,
param_t param
){
@@ -292,7 +300,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int kidx = block_k + thread_col*8;
const unsigned int kidx = start_k + block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
@@ -300,9 +308,10 @@ __device__ __forceinline__ void tileMemcpyLoadB(
const int curS = curIdx.w;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
const unsigned int src_index = thread_row * src_stride + kidx;
// TODO : move some checks outside of the loop
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t
&& curC < param.c && kidx < end_k){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);

@@ -350,9 +350,9 @@ int main(void)
{
ggml_time_init();
std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
// std::make_tuple(1,2,16,32,4,3,3,3),
std::make_tuple(1,2,16,32,4,3,3,3),
// std::make_tuple(320,1280,26,38,8,3,3,3),
std::make_tuple(1280,1280,26,38,8,3,3,3),
// std::make_tuple(1280,1280,26,38,8,3,3,3),
// std::make_tuple(320,1280,52,76,8,3,3,3),
// std::make_tuple(1280,1280,52,76,8,3,3,3),
// std::make_tuple(320,1280,104,152,8,3,3,3),