Refactor conv2d_implicit_kernel for improved readability and consistency; update parameter comments and remove unused code
This commit is contained in:
parent
5ffe97be9c
commit
4b0f9d571f
|
|
@ -2,8 +2,8 @@
|
|||
#include "convert.cuh"
|
||||
|
||||
typedef struct{
|
||||
unsigned int n; //batch szie
|
||||
unsigned int c; //channel number
|
||||
unsigned int n; //batch size
|
||||
unsigned int c; //number if channels
|
||||
unsigned int h; //height
|
||||
unsigned int w; //width
|
||||
unsigned int k; //number of filters
|
||||
|
|
@ -23,23 +23,18 @@ typedef struct{
|
|||
|
||||
template <typename T>
|
||||
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||
const T * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
const param_t param) {
|
||||
const T * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
const param_t param) {
|
||||
|
||||
extern __shared__ __align__(16 * 1024) char smem[];
|
||||
extern __shared__ unsigned char smem[];
|
||||
T *smemweight = reinterpret_cast<T *>(smem);
|
||||
float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);
|
||||
|
||||
int tx = threadIdx.x;
|
||||
int bx = blockIdx.x;
|
||||
int by = blockIdx.y;
|
||||
|
||||
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
|
||||
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
|
||||
// // printf("param.n=%d\n",param.n);
|
||||
// }
|
||||
// __syncthreads();
|
||||
|
||||
|
||||
// Warp tile
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
|
|
@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
int posh_ori[4];
|
||||
int posw_ori[4];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p;
|
||||
posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q;
|
||||
}
|
||||
|
|
@ -82,28 +76,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
float input_frag[2][8];
|
||||
float output_frag[8][8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; ++i)
|
||||
{
|
||||
for (int i = 0; i < 8; ++i){
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; ++j)
|
||||
{
|
||||
for (int j = 0; j < 8; ++j){
|
||||
output_frag[i][j] = 0;
|
||||
}
|
||||
}
|
||||
// ldg
|
||||
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
|
||||
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
|
||||
// }
|
||||
// __syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
|
||||
weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
|
||||
}
|
||||
else
|
||||
{
|
||||
else{
|
||||
weight_ldg_reg[i] = (T)0.f;
|
||||
}
|
||||
}
|
||||
|
|
@ -111,57 +96,46 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
|
||||
int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
int curH = posh_ori[i] + curR * param.d_h; // input h
|
||||
int curW = posw_ori[i] + curS * param.d_w; // input w
|
||||
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
|
||||
{
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
|
||||
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
|
||||
}
|
||||
else
|
||||
{
|
||||
else{
|
||||
input_ldg_reg[i] = 0.0;
|
||||
}
|
||||
}
|
||||
// sts
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
|
||||
}
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// lds
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
weight_frag[0][i] = smemweight[weight_lds_addr + i];
|
||||
weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
input_frag[0][i] = smeminput[input_lds_addr + i];
|
||||
input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
|
||||
}
|
||||
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8)
|
||||
{
|
||||
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){
|
||||
// ldg
|
||||
int weiOffsetTmp = crs + 8 + tx % 8;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
|
||||
weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
|
||||
}
|
||||
else
|
||||
{
|
||||
else{
|
||||
weight_ldg_reg[i] = (T)0.f;
|
||||
}
|
||||
}
|
||||
|
|
@ -170,76 +144,62 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
int curH = posh_ori[i] + curR * param.d_h; // input h
|
||||
int curW = posw_ori[i] + curS * param.d_w; // input w
|
||||
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
|
||||
{
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
|
||||
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
|
||||
}
|
||||
else
|
||||
{
|
||||
else{
|
||||
input_ldg_reg[i] = 0.f;
|
||||
}
|
||||
}
|
||||
int load_flag = write_flag ^ 1;
|
||||
#pragma unroll
|
||||
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs)
|
||||
{
|
||||
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
|
||||
weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
|
||||
input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; ++i)
|
||||
{
|
||||
for (int i = 0; i < 8; ++i){
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; ++j)
|
||||
{
|
||||
for (int j = 0; j < 8; ++j){
|
||||
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
// sts
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
|
||||
}
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
|
||||
}
|
||||
__syncthreads();
|
||||
write_flag ^= 1;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
|
||||
weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
{
|
||||
for (int i = 0; i < 4; ++i){
|
||||
input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
|
||||
input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; ++i)
|
||||
{
|
||||
for (int i = 0; i < 8; ++i){
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; ++j)
|
||||
{
|
||||
for (int j = 0; j < 8; ++j){
|
||||
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
|
||||
}
|
||||
}
|
||||
|
|
@ -247,35 +207,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
|
||||
// reuse smem
|
||||
float *smemoutput = reinterpret_cast<float *>(smem);
|
||||
// float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
|
||||
|
||||
// bias ldg/sts
|
||||
// if (tx < 128)
|
||||
// {
|
||||
// smembias[tx] = param.bias[by * 128 + tx];
|
||||
// }
|
||||
|
||||
uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
|
||||
uint32_t output_lds_addr = warp_id * 512 + lane_id;
|
||||
// uint32_t bias_lds_addr = warp_id / 2 * 32;
|
||||
|
||||
uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
|
||||
uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i)
|
||||
{
|
||||
for (int i = 0; i < 2; ++i){
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; ++j)
|
||||
{
|
||||
for (int j = 0; j < 2; ++j){
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int subi = 0; subi < 4; ++subi)
|
||||
{
|
||||
for (int subi = 0; subi < 4; ++subi){
|
||||
#pragma unroll
|
||||
for (int subj = 0; subj < 4; ++subj)
|
||||
{
|
||||
for (int subj = 0; subj < 4; ++subj){
|
||||
// output sts
|
||||
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
|
||||
}
|
||||
|
|
@ -283,11 +231,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < 16; ++subk)
|
||||
{
|
||||
for (int subk = 0; subk < 16; ++subk){
|
||||
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
|
||||
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
||||
// output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
|
||||
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
||||
}
|
||||
}
|
||||
|
|
@ -295,8 +241,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
||||
// const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
|
||||
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
||||
int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
|
||||
int blocky = (P.k + 127) / 128; // blocky number
|
||||
int blockz = P.n; // blockz number
|
||||
|
|
|
|||
Loading…
Reference in New Issue