diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index d9fabd9657..31205187c1 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -2,8 +2,8 @@
 #include "convert.cuh"

 typedef struct{
-    unsigned int n; //batch szie
-    unsigned int c; //channel number
+    unsigned int n; //batch size
+    unsigned int c; //number of channels
     unsigned int h; //height
     unsigned int w; //width
     unsigned int k; //number of filters
@@ -23,23 +23,18 @@ typedef struct{

 template <typename T>
 static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
-                                              const T * __restrict__ kernel,
-                                              float * __restrict__ output,
-                                              const param_t param) {
+                                              const T * __restrict__ kernel,
+                                              float * __restrict__ output,
+                                              const param_t param) {

-    extern __shared__ __align__(16 * 1024) char smem[];
+    extern __shared__ unsigned char smem[];
     T *smemweight = reinterpret_cast<T *>(smem);
     float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);

     int tx = threadIdx.x;
     int bx = blockIdx.x;
     int by = blockIdx.y;
-
-    // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
-    //     printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
-    //     // printf("param.n=%d\n",param.n);
-    // }
-    // __syncthreads();
+
     // Warp tile
     const int lane_id = threadIdx.x % 32;
@@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     int posh_ori[4];
     int posw_ori[4];
 #pragma unroll
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p;
         posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q;
     }
@@ -82,28 +76,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     float input_frag[2][8];
     float output_frag[8][8];
 #pragma unroll
-    for (int i = 0; i < 8; ++i)
-    {
+    for (int i = 0; i < 8; ++i){
 #pragma unroll
-        for (int j = 0; j < 8; ++j)
-        {
+        for (int j = 0; j < 8; ++j){
             output_frag[i][j] = 0;
         }
     }
     // ldg
-    // if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
-    //     printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
-    // }
-    // __syncthreads();
 #pragma unroll
-    for (int i = 0; i < 4; ++i)
-    {
-        if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
-        {
+    for (int i = 0; i < 4; ++i){
+        if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
             weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
         }
-        else
-        {
+        else{
             weight_ldg_reg[i] = (T)0.f;
         }
     }
@@ -111,57 +96,46 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
     int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
     int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
 #pragma unroll
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         int curH = posh_ori[i] + curR * param.d_h; // input h
         int curW = posw_ori[i] + curS * param.d_w; // input w
         int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
-        if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
-        {
+        if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
            input_ldg_reg[i] = input[inOffset + inOffsetTmp];
        }
-        else
-        {
+        else{
            input_ldg_reg[i] = 0.0;
        }
    }

     // sts
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
     }
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
     }

     __syncthreads();

     // lds
 #pragma unroll
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         weight_frag[0][i] = smemweight[weight_lds_addr + i];
         weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
     }
 #pragma unroll
-    for (int i = 0; i < 4; ++i)
-    {
+    for (int i = 0; i < 4; ++i){
         input_frag[0][i] = smeminput[input_lds_addr + i];
         input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
     }

-    for (int crs = 0; crs < param.r * param.s * param.c; crs += 8)
-    {
+    for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){
         // ldg
         int weiOffsetTmp = crs + 8 + tx % 8;
 #pragma unroll
-        for (int i = 0; i < 4; ++i)
-        {
-            if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
-            {
+        for (int i = 0; i < 4; ++i){
+            if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
                 weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
             }
-            else
-            {
+            else{
                 weight_ldg_reg[i] = (T)0.f;
             }
         }
@@ -170,76 +144,62 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
         curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset

 #pragma unroll
-        for (int i = 0; i < 4; ++i)
-        {
+        for (int i = 0; i < 4; ++i){
             int curH = posh_ori[i] + curR * param.d_h; // input h
             int curW = posw_ori[i] + curS * param.d_w; // input w
             int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
-            if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
-            {
+            if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
                 input_ldg_reg[i] = input[inOffset + inOffsetTmp];
             }
-            else
-            {
+            else{
                 input_ldg_reg[i] = 0.f;
             }
         }

         int load_flag = write_flag ^ 1;
 #pragma unroll
-        for (int subcrs = 0; subcrs < 8 - 1; ++subcrs)
-        {
+        for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){
 #pragma unroll
-            for (int i = 0; i < 4; ++i)
-            {
+            for (int i = 0; i < 4; ++i){
                 weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
                 weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
             }
 #pragma unroll
-            for (int i = 0; i < 4; ++i)
-            {
+            for (int i = 0; i < 4; ++i){
                 input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
                 input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
             }

 #pragma unroll
-            for (int i = 0; i < 8; ++i)
-            {
+            for (int i = 0; i < 8; ++i){
 #pragma unroll
-                for (int j = 0; j < 8; ++j)
-                {
+                for (int j = 0; j < 8; ++j){
                     output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
                 }
             }
         }

         // sts
-        for (int i = 0; i < 4; ++i)
-        {
+        for (int i = 0; i < 4; ++i){
             smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
         }
-        for (int i = 0; i < 4; ++i)
-        {
+        for (int i = 0; i < 4; ++i){
             smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
         }

         __syncthreads();
         write_flag ^= 1;

 #pragma unroll
-        for (int i = 0; i < 4; ++i)
-        {
+        for (int i = 0; i < 4; ++i){
             weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
             weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
         }
 #pragma unroll
-        for (int i = 0; i < 4; ++i)
-        {
+        for (int i = 0; i < 4; ++i){
             input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
             input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
         }
 #pragma unroll
-        for (int i = 0; i < 8; ++i)
-        {
+        for (int i = 0; i < 8; ++i){
 #pragma unroll
-            for (int j = 0; j < 8; ++j)
-            {
+            for (int j = 0; j < 8; ++j){
                 output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
             }
         }
@@ -247,35 +207,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,

     // reuse smem
     float *smemoutput = reinterpret_cast<float *>(smem);
-    // float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
-    // bias ldg/sts
-    // if (tx < 128)
-    // {
-    //     smembias[tx] = param.bias[by * 128 + tx];
-    // }

     uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
     uint32_t output_lds_addr = warp_id * 512 + lane_id;
-    // uint32_t bias_lds_addr = warp_id / 2 * 32;

     uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
     uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;
 #pragma unroll
-    for (int i = 0; i < 2; ++i)
-    {
+    for (int i = 0; i < 2; ++i){
 #pragma unroll
-        for (int j = 0; j < 2; ++j)
-        {
+        for (int j = 0; j < 2; ++j){
             __syncthreads();
-
 #pragma unroll
-            for (int subi = 0; subi < 4; ++subi)
-            {
+            for (int subi = 0; subi < 4; ++subi){
 #pragma unroll
-                for (int subj = 0; subj < 4; ++subj)
-                {
+                for (int subj = 0; subj < 4; ++subj){
                     // output sts
                     smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
                 }
@@ -283,11 +231,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
             __syncthreads();

 #pragma unroll
-            for (int subk = 0; subk < 16; ++subk)
-            {
+            for (int subk = 0; subk < 16; ++subk){
                 int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
                 if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
-                    // output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
                     output[outOffset] = smemoutput[output_lds_addr + subk * 32];
             }
         }
@@ -295,8 +241,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 }

 template <typename T>
-static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
-    // const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
     int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
     int blocky = (P.k + 127) / 128; // blocky number
     int blockz = P.n; // blockz number