diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 0a5c370f29..0b410a460a 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -1,8 +1,11 @@ -#include "conv2d-implicit.cuh" +// #include +#include "ggml.h" +#include "common.cuh" #include "convert.cuh" - +#include "conv2d-implicit.cuh" static const int WARPSIZE = 32; // warpSize is not constexpr +typedef unsigned int uint; static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.x; @@ -20,7 +23,8 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri template + // layout: 0, NHWC; 1, NCHW + const int layout, const bool vec_load, const int ksplit, const int PAD=4> static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const T * __restrict__ kernel, float * __restrict__ output, @@ -76,7 +80,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // int inOffset = (ksplit > 0): z * param.c * param.h * param.w ; // int weiOffset = (by * BN + tx / 8 * 4) * param.c * param.r * param.s; - int inChannelOffset = param.c * param.w; + int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; @@ -125,16 +129,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, if(vec_load){ // if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){ if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){ - if constexpr (std::is_same_v(T, float)){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; smemweight[weight_sts_addr + offset + 0] = tmp.x; smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; }else{ // read 4 halves // half val[4]; - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; - half *val = reinterpret_cast(&tmp); + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0]; + const half *val = reinterpret_cast(&tmp); // val[1] = reinterpret_cast(&tmp.y); smemweight[weight_sts_addr + offset + 0] = val[0]; smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; @@ -177,15 +181,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - + // const uint curR = fastdiv(start_k + innerColA * 4, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[input_sts_addr + offset + 0] = tmp.x; smeminput[input_sts_addr + offset + BM] = tmp.y; smeminput[input_sts_addr + offset + 2*BM] = tmp.z; @@ -198,14 +215,27 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - + // const uint curR = fastdiv(start_k + innerColA * 4 + i, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; } else { smeminput[input_sts_addr + offset + i*BM] = 0.f; @@ -398,15 +428,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { if(vec_load){ if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){ - if constexpr (std::is_same_v(T, float)){ - float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + if constexpr (std::is_same_v){ + float4 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; } else { - float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; - half *val = reinterpret_cast(&tmp); + float2 tmp = reinterpret_cast(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0]; + const half *val = reinterpret_cast(&tmp); smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1]; smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; @@ -437,15 +467,29 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q; int inOffset = n * param.c * param.h * param.w ; if(vec_load){ - const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curR = fastdiv(innerColA * 4 + crs + BK, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast(&input[inOffset + inOffsetTmp])[0]; smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x; smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y; smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z; @@ -458,14 +502,28 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, } else { #pragma unroll for (int i = 0; i < 4; ++i){ - const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset - const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curR = fastdiv(innerColA * 4 + crs + BK + i, param.SC_fastdiv); // channel offset + // const uint curS = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + // const uint curC = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const uint curC = layout == 0 ? cur2 : cur0; + const uint curR = layout == 0 ? cur0 : cur1; + const uint curS = layout == 0 ? cur1 : cur2; const int curH = posh_ori + curR; // input h const int curW = posw_ori + curS; // input w if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){ - int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp]; } else { smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f; @@ -684,26 +742,37 @@ constexpr static int conv_shapes[][NUM_VARIANTS] = { { 256, 256, 128} // NUM_THREADS }; -template +template static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { - int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number - int blocky = (P.k + 127) / 128; // blocky number + + const uint BM = conv_shapes[0][CONV_SHAPE]; + const uint BN = conv_shapes[1][CONV_SHAPE]; + const uint BK = conv_shapes[2][CONV_SHAPE]; + const uint WM = conv_shapes[3][CONV_SHAPE]; + const uint WN = conv_shapes[4][CONV_SHAPE]; + const uint WNITER = conv_shapes[5][CONV_SHAPE]; + const uint TM = conv_shapes[6][CONV_SHAPE]; + const uint TN = conv_shapes[7][CONV_SHAPE]; + const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE]; + int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number + int blocky = (P.k + BN-1) / BN; // blocky number int blockz = P.n; // blockz number - int threadx = CUDA_CONV2D_IMPLICT_BLOCK_SIZE; // threadx number per block + // int threadx = NUM; // threadx number per block int thready = 1; // thready number per block int threadz = 1; // threadz number per block - dim3 thblock(threadx, thready, threadz); + dim3 thblock(NUM_THREADS, thready, threadz); dim3 grid(blockx, blocky, blockz); - int smem_size = 24 * 1024; - conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); + // int smem_size = 24 * 1024; + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); } static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) { - conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -745,9 +814,12 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t total = B * OC * OH * OW; param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; - params.SC_fastdiv = init_fastdiv_values(KW*KH); + params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); params.C_fastdiv = init_fastdiv_values(IC); + params.RS_fastdiv = init_fastdiv_values(KW*KH); + params.S_fastdiv = init_fastdiv_values(KW); + params.nchw = false; if (kernel->type == GGML_TYPE_F16) { conv2d_implicit_cuda_f16(X_D, (half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 4fe6134873..d2f3cffcc3 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -17,9 +17,12 @@ typedef struct{ unsigned int d_w; //dilation width unsigned int Oh; //output height unsigned int Ow; //output width + bool nchw; uint3 SC_fastdiv; uint3 OW_fastdiv; uint3 C_fastdiv; + uint3 RS_fastdiv; + uint3 S_fastdiv; } param_t; diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp index e963e2b361..0ac438a137 100644 --- a/tests/test-conv2d-implicit.cpp +++ b/tests/test-conv2d-implicit.cpp @@ -339,21 +339,21 @@ int main(void) { ggml_time_init(); std::vector> configs = { - std::make_tuple(64,64,48,64), - std::make_tuple(320,320,104,152), - std::make_tuple(640,640,52,76), - std::make_tuple(640,640,104,152), - std::make_tuple(960,320,104,152), - std::make_tuple(1280,1280,26,38), - std::make_tuple(1280,640,52,76), - std::make_tuple(1920,1280,26,38), - std::make_tuple(2560,1280,26,38), - std::make_tuple(512,512,104,152), - std::make_tuple(512,512,208,304), + // std::make_tuple(64,64,48,64), + // std::make_tuple(320,320,104,152), + // std::make_tuple(640,640,52,76), + // std::make_tuple(640,640,104,152), + // std::make_tuple(960,320,104,152), + // std::make_tuple(1280,1280,26,38), + // std::make_tuple(1280,640,52,76), + // std::make_tuple(1920,1280,26,38), + // std::make_tuple(2560,1280,26,38), + // std::make_tuple(512,512,104,152), + // std::make_tuple(512,512,208,304), std::make_tuple(512,256,416,608), - std::make_tuple(256,128,832,1216), - std::make_tuple(256,256,832,1216), - std::make_tuple(320,256,1024,1920) + // std::make_tuple(256,128,832,1216), + // std::make_tuple(256,256,832,1216), + // std::make_tuple(320,256,1024,1920) }; int k = 0;