added split-k mode for skinny mnk shapes
parent 275c08d25d
commit 6f44f47113
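For orientation, here is a minimal, self-contained sketch of the split-k idea this commit applies to the implicit-GEMM conv2d kernel (illustrative only; unlike the names in the diff below, the kernel and parameter names here — partial_reduce_example, x, w, partials, nout — are hypothetical). The K = c·r·s reduction dimension is divided across blockIdx.z, each slice accumulates a partial result over its own [start_k, end_k) range, and a second pass sums the ksplit partial buffers:

// Sketch: each z-slice reduces one K-range and writes a partial output.
__global__ void partial_reduce_example(const float * x, const float * w,
                                       float * partials, int nout, int K, int ksplit) {
    const int out = blockIdx.x * blockDim.x + threadIdx.x;
    if (out >= nout) return;
    const int ks      = (K + ksplit - 1) / ksplit; // K elements per slice
    const int start_k = blockIdx.z * ks;
    const int end_k   = min(start_k + ks, K);
    float sum = 0.0f;
    for (int k = start_k; k < end_k; ++k) {
        sum += x[out * K + k] * w[k];
    }
    // partials laid out as ksplit consecutive slabs of nout elements each
    partials[blockIdx.z * nout + out] = sum;
}

This trades one extra global-memory pass (the reduction) for much better SM occupancy when M and N are small ("skinny") but K is large.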
@@ -13,18 +13,19 @@ constexpr uint WARPSIZE = 32;
-//currently not use; in future for split-k kernels
-// static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
-//     const int row = blockIdx.x;
-//     const int col = threadIdx.x;
-//     float sum = 0.0f;
-//     if (row * blockDim.x + col < ncols) {
-//         for (int i = 0; i < nrows; ++i){
-//             sum += x[i * ncols + row * blockDim.x + col];
-//         }
-//         dst[row * blockDim.x + col] = sum;
-//     }
-// }
+template <typename src_T, typename dst_T>
+static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    if (row * blockDim.x + col < ncols) {
+        for (int i = 0; i < nrows; ++i){
+            sum += ggml_cuda_cast<float>(x[i * ncols + row * blockDim.x + col]);
+        }
+        dst[row * blockDim.x + col] = ggml_cuda_cast<dst_T>(sum);
+    }
+}
 
 template <typename src_T, typename dst_T>
 static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
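The templated reduce_f32 above performs that second pass: it treats x as nrows stacked buffers of ncols elements and computes dst[j] = Σ_i x[i·ncols + j], accumulating in float and casting on the way in and out. A hedged launch sketch matching how this commit uses it in the host code further down (note the host passes the per-slice element count as ncols and ksplit as nrows; `partials` stands for the pool-allocated partial buffer):

// Sketch: fold ksplit fp16 partial buffers into one fp32 buffer.
const unsigned int nelem = P.n * P.k * P.Oh * P.Ow;   // outputs per slice
const dim3 block_dims(512, 1, 1);
const dim3 block_nums((nelem + 511) / 512, 1, 1);
reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(partials, Y_D, nelem, ksplit);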
@@ -705,26 +706,32 @@ __device__ __forceinline__ void ldmatrix_b(
 }
 
 template<const int BM, const int BN, const int BK, const int WM, const int WN,
-        const int WK, const int NUM_THREADS>
+        const int WK, const int ksplit, const int NUM_THREADS>
 static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                                               const half * __restrict__ kernel,
                                               half * __restrict__ output,
                                               const param_t param) {
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
     constexpr unsigned int MMA_M = 16;
     constexpr unsigned int MMA_N = 8;
 
     const unsigned int K = param.c * param.r * param.s;
     const uint inChannelOffset = param.c * param.w;
-    const uint weightKOffset = param.c * param.r * param.s;
+    const uint weightKOffset = K;
 
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
-    const unsigned int num_block_tiles_k = (K + (BK-1)) / BK;
+    const unsigned int z = blockIdx.z;
+
+    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
+    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
 
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
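To make the slice bounds above concrete (numbers hypothetical): with weightKOffset = c·r·s = 320·3·3 = 2880 and ksplit = 8, ks = (2880 + 7) / 8 = 360, so slice z reduces K indices [360·z, 360·z + 360) and num_block_tiles_k counts BK-wide tiles over 360 values rather than all 2880. With ksplit = 0 the bounds degenerate to [0, weightKOffset) and the kernel behaves exactly as the old single-pass version.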
@@ -770,8 +777,8 @@ constexpr unsigned int MMA_N = 8;
 
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, inChannelOffset, param);
-    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, weightKOffset, param);
+    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param);
+    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
@@ -781,8 +788,8 @@ constexpr unsigned int MMA_N = 8;
     if (block_k != num_block_tiles_k){
         const half* A_block_gmem = input;
         const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
-        tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, inChannelOffset, param);
-        tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, weightKOffset, param);
+        tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
+        tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
     }
     half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
     half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
@@ -813,6 +820,8 @@ constexpr unsigned int MMA_N = 8;
     }
 
+
+
     if (block_k != num_block_tiles_k)
     {
         // switch smem buffers each iteration
@@ -863,11 +872,18 @@ constexpr unsigned int MMA_N = 8;
                 const uint gemm_i = n_idx + j*32;
                 const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
                 const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
-                if(n < param.n && row < param.k && col < param.Oh * param.Ow){
-                    const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                if (n < param.n && row < param.k && col < param.Oh * param.Ow) {
                     uint idx = output_lds_addr + subk + j*32*BN/2;
                     idx = idx ^ ((idx & 0b1110000000) >> 4);
-                    output[outOffset] = smemoutput[idx];
+                    if constexpr (ksplit > 0) {
+                        const uint outOffset = z * param.n * param.k * param.Oh * param.Ow +
+                                               n * param.k * param.Oh * param.Ow +
+                                               row * param.Oh * param.Ow + col;
+                        output[outOffset] = smemoutput[idx];
+                    } else {
+                        const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                        output[outOffset] = smemoutput[idx];
+                    }
                 }
             }
         }
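Note on the epilogue above: each z-slice stages its output tile through the same XOR-swizzled shared-memory layout as before, but in the split-k case writes it into its own slab of the partial buffer at offset z·n·k·Oh·Ow. The ksplit partials therefore sit back to back in global memory, which is exactly the layout the reduce_f32 pass consumes (slab index = reduction row i). For a hypothetical n·k·Oh·Ow = 98304, slice z = 3 starts writing at element 3·98304 = 294912.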
@@ -952,7 +968,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
 
     const half *X_H = input_f16.get();
     const half *K_H = kernel_f16.get();
-    ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
 
     constexpr unsigned int BM_dim = 256;
     constexpr unsigned int BN_dim = 256;
@@ -972,16 +987,41 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
     constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
     const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
 
-    cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
-        cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
-    dim3 gridDim(BlocksN, BlocksM);
-    dim3 blockDim(ThreadsN, ThreadsM);
-
-    conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
-        WM_dim, WN_dim, WK_dim, NumThreads>
-        <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+    const unsigned int K2MN = 8;
+
+    if (P.c * P.r * P.s > K2MN * P.n * P.Oh * P.Ow || P.c * P.r * P.s > K2MN * P.k) {
+        const unsigned int ksplit = 8;
+        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
+
+        cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
+        dim3 gridDim(BlocksN, BlocksM, ksplit);
+        dim3 blockDim(ThreadsN, ThreadsM);
+
+        conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
+            WM_dim, WN_dim, WK_dim, ksplit, NumThreads>
+            <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+
+        const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
+        const unsigned int blockx = (nrows + 511) / 512;
+        const dim3 block_nums(blockx, 1, 1);
+        const dim3 block_dims(512, 1, 1);
+        reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
+
+    } else {
+        ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
+
+        cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
+            cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
+        dim3 gridDim(BlocksN, BlocksM);
+        dim3 blockDim(ThreadsN, ThreadsM);
+
+        conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim,
+            WM_dim, WN_dim, WK_dim, 0, NumThreads>
+            <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
+    }
 } else{
     conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
 }
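When does the split-k branch fire? K2MN = 8 is the skinniness threshold: split-k is chosen when the GEMM K dimension (c·r·s) exceeds 8× either output dimension (N = n·Oh·Ow or M = k). A hypothetical reading of two shapes from the test list below, assuming the tuple order is (c, k, h, w, r, s) and n = 1: (320, 4, 64, 96, 3, 3) gives c·r·s = 2880 > 8·4 = 32, so it launches with gridDim.z = ksplit = 8 followed by the reduce_f32 pass (which writes float directly, replacing the separate to_fp32_cuda conversion); (1280, 1280, 26, 38, 1, 1) gives c·r·s = 1280, below both 8·1280 = 10240 and 8·(26·38) = 7904, so it keeps the original single-pass path via ksplit = 0.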
@@ -32,6 +32,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -57,9 +59,9 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // kernel r offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
@@ -68,7 +70,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // out of range: zero-fill (one float4 = 8 halves)
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -91,7 +93,8 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
-    // const unsigned int src_stride,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inChannelOffset,
     param_t param
 )
@@ -128,9 +131,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
     int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
     unsigned int inOffset = n * param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // kernel r offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
     int curH = posh_ori + curR * param.d_h; // input h
     int curW = posw_ori + curS * param.d_w; // input w
     // apply swizzle to the dst index
@@ -138,7 +141,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.r && curS < param.s && curC < param.c){
+            curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
         } else{
@@ -164,6 +167,8 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     // const unsigned int src_stride,
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int inChannelOffset,
     param_t param
 ){
@@ -194,13 +199,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
     int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
     unsigned int inOffset = n * param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // kernel r offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
     int curH = posh_ori + curR * param.d_h; // input h
     int curW = posw_ori + curS * param.d_w; // input w
     if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-        curR < param.r && curS < param.s && curC < param.c){
+        curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
         const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
         dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
     } else{
@@ -227,6 +232,8 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
     const unsigned int block_k,
+    const unsigned int start_k,
+    const unsigned int end_k,
     const unsigned int src_stride,
     param_t param
 ){
@@ -249,14 +256,14 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // kernel r offset
+    const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
+    const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // channel offset
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // out of range: zero-fill (one float4 = 8 halves)
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
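All four copy helpers now take the same (start_k, end_k) pair and extend their existing bounds checks with the element's absolute position along K. The shared pattern, reduced to its core (a sketch; in_bounds and k_abs are stand-ins for the per-helper curR/curS/curC and curH/curW checks and index expressions):

// Each thread moves one float4 = 8 halves per iteration.
const unsigned int k_abs = start_k + block_k + thread_col * 8;  // absolute K index
if (in_bounds && k_abs < end_k) {
    dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
} else {
    // Vectors starting at or past end_k are zero-filled, so tiles at the
    // slice boundary contribute zeros to the MMA accumulators.
    dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}

For the swizzle variants block_k is zero (they load the first block tile); otherwise the logic is identical.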
@@ -309,7 +309,6 @@ int main(void)
     std::make_tuple(4,320,64,96,3,3),
     std::make_tuple(320,4,64,96,3,3),
     std::make_tuple(640,640,96,128,3,3),
     std::make_tuple(320,1280,26,38,3,3),
     std::make_tuple(1280,1280,26,38,1,1),
     std::make_tuple(256,128,768,1024,3,3),
     std::make_tuple(128,3,768,1024,3,3),
@@ -385,14 +384,13 @@ int main(void)
 
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
     // for(int i = 0; i < conv2d_data.size(); i++) {
     // // float diff = fabs(conv2d_data[i] - wino_data[i]);
-    // float diff = fabs(im2col_data[i] - wino_data[i]);
-    // float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
     // // for(int i = 26*38; i < 2*26*38; i++) {
     // // for(int i = 0; i < conv2d_data.size(); i++) {
+    // float diff = fabs(im2col_data[i] - conv2d_data[i]);
     // // if(diff > 0.5) {
-    // printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
+    // printf("(%7.3f, %7.3f, %.2f, %d) \n",
     //     im2col_data[i], conv2d_data[i],
-    //     wino_data[i], diff, diff1, i);
+    //     diff, i);
     // // break;
     // // }
     // }