Refactor conv2d_implicit_kernel for improved readability and consistency; update parameter comments and remove unused code

This commit is contained in:
bssrdf 2025-09-05 08:29:57 -04:00
parent 5ffe97be9c
commit 4b0f9d571f
1 changed files with 44 additions and 99 deletions

View File

@ -2,8 +2,8 @@
#include "convert.cuh"
typedef struct{
unsigned int n; //batch szie
unsigned int c; //channel number
unsigned int n; //batch size
unsigned int c; //number if channels
unsigned int h; //height
unsigned int w; //width
unsigned int k; //number of filters
@ -23,23 +23,18 @@ typedef struct{
template <typename T>
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const T * __restrict__ kernel,
float * __restrict__ output,
const param_t param) {
const T * __restrict__ kernel,
float * __restrict__ output,
const param_t param) {
extern __shared__ __align__(16 * 1024) char smem[];
extern __shared__ unsigned char smem[];
T *smemweight = reinterpret_cast<T *>(smem);
float *smeminput = reinterpret_cast<float *>(smem + 16 * 1024);
int tx = threadIdx.x;
int bx = blockIdx.x;
int by = blockIdx.y;
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
// // printf("param.n=%d\n",param.n);
// }
// __syncthreads();
// Warp tile
const int lane_id = threadIdx.x % 32;
@ -60,8 +55,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
int posh_ori[4];
int posw_ori[4];
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p;
posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q;
}
@ -82,28 +76,19 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
float input_frag[2][8];
float output_frag[8][8];
#pragma unroll
for (int i = 0; i < 8; ++i)
{
for (int i = 0; i < 8; ++i){
#pragma unroll
for (int j = 0; j < 8; ++j)
{
for (int j = 0; j < 8; ++j){
output_frag[i][j] = 0;
}
}
// ldg
// if(tx == 0 && bx == 0 && by == 0 && blockIdx.z == 0){
// printf("param.n=%d, param.c=%d, param.h=%d, param.w=%d, param.k=%d, param.r=%d, param.s=%d, param.u=%d, param.v=%d, param.p=%d, param.q=%d, param.d_h=%d, param.d_w=%d, param.Oh=%d, param.Ow=%d\n",param.n,param.c,param.h,param.w,param.k,param.r,param.s,param.u,param.v,param.p,param.q,param.d_h,param.d_w,param.Oh,param.Ow);
// }
// __syncthreads();
#pragma unroll
for (int i = 0; i < 4; ++i)
{
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
{
for (int i = 0; i < 4; ++i){
if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset];
}
else
{
else{
weight_ldg_reg[i] = (T)0.f;
}
}
@ -111,57 +96,46 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset
int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
int curH = posh_ori[i] + curR * param.d_h; // input h
int curW = posw_ori[i] + curS * param.d_w; // input w
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
{
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
}
else
{
else{
input_ldg_reg[i] = 0.0;
}
}
// sts
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
smemweight[weight_sts_addr + i] = weight_ldg_reg[i];
}
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
smeminput[input_sts_addr + i * 32] = input_ldg_reg[i];
}
__syncthreads();
// lds
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
weight_frag[0][i] = smemweight[weight_lds_addr + i];
weight_frag[0][i + 4] = smemweight[weight_lds_addr + i + 16];
}
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
input_frag[0][i] = smeminput[input_lds_addr + i];
input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32];
}
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8)
{
for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){
// ldg
int weiOffsetTmp = crs + 8 + tx % 8;
#pragma unroll
for (int i = 0; i < 4; ++i)
{
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k)
{
for (int i = 0; i < 4; ++i){
if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){
weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset];
}
else
{
else{
weight_ldg_reg[i] = (T)0.f;
}
}
@ -170,76 +144,62 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
int curH = posh_ori[i] + curR * param.d_h; // input h
int curW = posw_ori[i] + curS * param.d_w; // input w
int inOffsetTmp = curC * inChannelOffset + curH * param.w + curW;
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c)
{
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && curC < param.c){
input_ldg_reg[i] = input[inOffset + inOffsetTmp];
}
else
{
else{
input_ldg_reg[i] = 0.f;
}
}
int load_flag = write_flag ^ 1;
#pragma unroll
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs)
{
for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i];
weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16];
}
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i];
input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
}
#pragma unroll
for (int i = 0; i < 8; ++i)
{
for (int i = 0; i < 8; ++i){
#pragma unroll
for (int j = 0; j < 8; ++j)
{
for (int j = 0; j < 8; ++j){
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
}
}
}
// sts
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
smemweight[write_flag * 132 * 8 + weight_sts_addr + i] = weight_ldg_reg[i];
}
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
smeminput[write_flag * 128 * 8 + input_sts_addr + i * 32] = input_ldg_reg[i];
}
__syncthreads();
write_flag ^= 1;
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
weight_frag[0][i] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i];
weight_frag[0][i + 4] = smemweight[(load_flag ^ 1) * 132 * 8 + weight_lds_addr + i + 16];
}
#pragma unroll
for (int i = 0; i < 4; ++i)
{
for (int i = 0; i < 4; ++i){
input_frag[0][i] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i];
input_frag[0][i + 4] = smeminput[(load_flag ^ 1) * 128 * 8 + input_lds_addr + i + 32];
}
#pragma unroll
for (int i = 0; i < 8; ++i)
{
for (int i = 0; i < 8; ++i){
#pragma unroll
for (int j = 0; j < 8; ++j)
{
for (int j = 0; j < 8; ++j){
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
}
}
@ -247,35 +207,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
// reuse smem
float *smemoutput = reinterpret_cast<float *>(smem);
// float *smembias = reinterpret_cast<float *>(smem + 16 * 1024);
// bias ldg/sts
// if (tx < 128)
// {
// smembias[tx] = param.bias[by * 128 + tx];
// }
uint32_t output_sts_addr = warp_id * 512 + mma_tid_y * 4 * 8 * 4 + mma_tid_x * 4;
uint32_t output_lds_addr = warp_id * 512 + lane_id;
// uint32_t bias_lds_addr = warp_id / 2 * 32;
uint32_t m_idx = blockIdx.y * 128 + warp_id / 2 * 32;
uint32_t n_idx = blockIdx.x * 128 + warp_id % 2 * 64 + lane_id;
#pragma unroll
for (int i = 0; i < 2; ++i)
{
for (int i = 0; i < 2; ++i){
#pragma unroll
for (int j = 0; j < 2; ++j)
{
for (int j = 0; j < 2; ++j){
__syncthreads();
#pragma unroll
for (int subi = 0; subi < 4; ++subi)
{
for (int subi = 0; subi < 4; ++subi){
#pragma unroll
for (int subj = 0; subj < 4; ++subj)
{
for (int subj = 0; subj < 4; ++subj){
// output sts
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
}
@ -283,11 +231,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
__syncthreads();
#pragma unroll
for (int subk = 0; subk < 16; ++subk)
{
for (int subk = 0; subk < 16; ++subk){
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// output[outOffset] = smemoutput[output_lds_addr + subk * 32] + smembias[bias_lds_addr + i * 16 + subk];
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
}
}
@ -295,8 +241,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
}
template <typename T>
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
// const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) {
int blockx = ((P.Oh * P.Ow + 127) / 128); // blockx number
int blocky = (P.k + 127) / 128; // blocky number
int blockz = P.n; // blockz number