WIP: build ok
This commit is contained in:
parent
52455b8a6d
commit
0a64ea8ff8
|
|
@ -62,6 +62,29 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
|
|||
}
|
||||
}
|
||||
|
||||
template<const int layout>
|
||||
__device__ int4 inputIndices(const uint kidx, param_t param) {
|
||||
|
||||
const uint cur0 = fastdiv(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const uint cur0_res = fastmodulo(kidx,
|
||||
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const uint cur1_res = fastmodulo(cur0_res,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastdiv(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur3 = fastmodulo(cur1_res,
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur3 : cur0;
|
||||
const uint curT = layout == 0 ? cur0 : cur1;
|
||||
const uint curR = layout == 0 ? cur1 : cur2;
|
||||
const uint curS = layout == 0 ? cur2 : cur3;
|
||||
return make_int4(curC, curT, curR, curS);
|
||||
|
||||
}
|
||||
|
||||
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
|
||||
// layout: 0, NHWC; 1, NCHW
|
||||
|
|
@ -80,7 +103,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
const uint bx = blockIdx.x;
|
||||
const uint by = blockIdx.y;
|
||||
|
||||
const uint PQZ = param.Oh * param.Ow * param.Oz;
|
||||
const uint PQZ = param.Oh * param.Ow * param.Od;
|
||||
|
||||
// Warp tile
|
||||
const uint lane_id = tx % WARPSIZE;
|
||||
|
|
@ -102,6 +125,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w;
|
||||
int inDepthOffset = layout == 0 ? param.h * param.c * param.w : param.d * param.h * param.w;
|
||||
int weightKOffset = param.c * param.r * param.s * param.t;
|
||||
int inNOffset = param.c * param.w * param.h * param.d;
|
||||
|
||||
const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
|
||||
const uint start_k = (ksplit > 0)? z * ks: 0;
|
||||
|
|
@ -158,31 +182,44 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
|
||||
|
||||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
const uint inKOffset = start_k + innerColA * 4;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
|
||||
const unsigned int npqz_res = (bx * BM + innerRowA + offset) % PQZ;
|
||||
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: bx * BM + innerRowA + offset, param.OWOH_fastdiv);
|
||||
const unsigned int gemm_i = bx * BM + innerRowA + offset;
|
||||
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQZ : z;
|
||||
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
|
||||
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv);
|
||||
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
int inOffset = n * param.c * param.h * param.w * param.d;
|
||||
int inOffset = n * inNOffset;
|
||||
if(vec_load){
|
||||
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int curD = posd_ori + curT * param.dilation2; // input w
|
||||
const int curH = posh_ori + curR * param.dilation1; // input h
|
||||
const int curW = posw_ori + curS * param.dilation0; // input w
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && start_k + innerColA * 4 < end_k){
|
||||
// const uint cur0 = fastdiv(inKOffset,
|
||||
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
// const uint cur0_res = fastmodulo(inKOffset,
|
||||
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
// const uint cur1 = fastdiv(cur0_res,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
// const uint cur1_res = fastmodulo(cur0_res,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
// const uint cur2 = fastdiv(cur1_res,
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint cur3 = fastmodulo(cur1_res,
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint curC = layout == 0 ? cur3 : cur0;
|
||||
// const uint curT = layout == 0 ? cur0 : cur1;
|
||||
// const uint curR = layout == 0 ? cur1 : cur2;
|
||||
// const uint curS = layout == 0 ? cur2 : cur3;
|
||||
const int4 curIdx = inputIndices<layout>(inKOffset, param);
|
||||
// const int curD = posd_ori + curT * param.dilation2; // input w
|
||||
// const int curH = posh_ori + curR * param.dilation1; // input h
|
||||
// const int curW = posw_ori + curS * param.dilation0; // input w
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||
|
|
@ -199,23 +236,47 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
|
||||
// const uint cur0 = fastdiv(inKOffset + i,
|
||||
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
// const uint cur0_res = fastmodulo(inKOffset + i,
|
||||
// layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv); // channel offset
|
||||
// const uint cur1 = fastdiv(cur0_res,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
// const uint cur1_res = fastmodulo(cur0_res,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // kernel r offset
|
||||
// const uint cur2 = fastdiv(cur1_res,
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint cur3 = fastmodulo(cur1_res,
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint curC = layout == 0 ? cur3 : cur0;
|
||||
// const uint curT = layout == 0 ? cur0 : cur1;
|
||||
// const uint curR = layout == 0 ? cur1 : cur2;
|
||||
// const uint curS = layout == 0 ? cur2 : cur3;
|
||||
const int4 curIdx = inputIndices<layout>(inKOffset + i, param);
|
||||
// const int curD = posd_ori + curT * param.dilation2; // input w
|
||||
// const int curH = posh_ori + curR * param.dilation1; // input h
|
||||
// const int curW = posw_ori + curS * param.dilation0; // input w
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
// const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
// const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint curC = layout == 0 ? cur2 : cur0;
|
||||
// const uint curR = layout == 0 ? cur0 : cur1;
|
||||
// const uint curS = layout == 0 ? cur1 : cur2;
|
||||
// const int curH = posh_ori + curR * param.d_h; // input h
|
||||
// const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKOffset + i < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
|
|
@ -242,6 +303,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
|
||||
threadColInWarp * TN + i];
|
||||
|
||||
// main block k loop
|
||||
for (int crs = start_k; crs < end_k; crs += BK) {
|
||||
|
||||
int load_flag = write_flag ^ 1;
|
||||
|
|
@ -317,33 +379,50 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
}
|
||||
}
|
||||
}
|
||||
const uint inKkOffset = innerColA * 4 + crs + BK;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||
int inOffset = n * param.c * param.h * param.w ;
|
||||
// int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
// const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||
// const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
||||
// const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||
// int inOffset = n * param.c * param.h * param.w ;
|
||||
const unsigned int gemm_i = bx * BM + innerRowA + offset;
|
||||
int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
|
||||
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
|
||||
const int posd_ori = fastdiv((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv) * param.stride2 - param.padding2;
|
||||
const int ohow_res = fastmodulo((ksplit > 0) ? npqz_res: gemm_i, param.OHOW_fastdiv);
|
||||
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
int inOffset = n * inNOffset;
|
||||
if(vec_load){
|
||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int4 curIdx = inputIndices<layout>(inKkOffset, param);
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint curC = layout == 0 ? cur2 : cur0;
|
||||
// const uint curR = layout == 0 ? cur0 : cur1;
|
||||
// const uint curS = layout == 0 ? cur1 : cur2;
|
||||
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
|
||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
// const int curH = posh_ori + curR * param.d_h; // input h
|
||||
// const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && inKkOffset < end_k){
|
||||
// int inOffsetTmp = layout == 0 ?
|
||||
// curH * inChannelOffset + curW * param.c + curC:
|
||||
// curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
|
|
@ -357,24 +436,33 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
// const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
// const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
// layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
// layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
// const uint curC = layout == 0 ? cur2 : cur0;
|
||||
// const uint curR = layout == 0 ? cur0 : cur1;
|
||||
// const uint curS = layout == 0 ? cur1 : cur2;
|
||||
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
|
||||
// const int curH = posh_ori + curR * param.d_h; // input h
|
||||
// const int curW = posw_ori + curS * param.d_w; // input w
|
||||
const int4 curIdx = inputIndices<layout>(inKkOffset + i, param);
|
||||
const int curD = posd_ori + curIdx.y * param.dilation2; // input w
|
||||
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
|
||||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
|
||||
// int inOffsetTmp = layout == 0 ?
|
||||
// curH * inChannelOffset + curW * param.c + curC:
|
||||
// curC * inChannelOffset + curH * param.w + curW;
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && inKkOffset + i < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inDepthOffset + curD * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
|
|
@ -448,15 +536,16 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < TM * TN; ++subk){
|
||||
// output: [N*OC, OD, OH, OW]
|
||||
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
|
||||
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
|
||||
const int n = (ksplit > 0) ? gemm_i / PQ : z;
|
||||
const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i;
|
||||
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
|
||||
const int n = (ksplit > 0) ? fastdiv(gemm_i, param.PQZ_fastdiv) : z;
|
||||
const int col = (ksplit > 0) ? fastmodulo(gemm_i, param.PQZ_fastdiv) : gemm_i;
|
||||
if (n < param.n && row < param.k && col < PQZ){
|
||||
const uint outOffset = ksplit > 0 ?
|
||||
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
|
||||
row * param.Oh * param.Ow + col :
|
||||
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||
// z * param.n * param.k * PQZ + n * param.k * PQZ + row * PQZ + col :
|
||||
((z * param.n + n) * param.k + row) * PQZ + col :
|
||||
(z * param.k + row) * PQZ + col;
|
||||
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
|
||||
}
|
||||
}
|
||||
|
|
@ -464,7 +553,7 @@ static __global__ void conv3d_implicit_kernel(const float * __restrict__ input,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#if 0
|
||||
|
||||
template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
|
||||
__device__ __forceinline__ void ldmatrix_a(
|
||||
|
|
@ -885,6 +974,7 @@ constexpr unsigned int MMA_N = 8;
|
|||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#define NUM_VARIANTS 4
|
||||
|
||||
|
|
@ -925,70 +1015,70 @@ static void conv3d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
|
|||
dim3 thblock(NUM_THREADS, thready, threadz);
|
||||
dim3 grid(blockx, blocky, blockz);
|
||||
|
||||
conv2d_implicit_kernel<T, BM, BN, BK, WM, WN,
|
||||
conv3d_implicit_kernel<T, BM, BN, BK, WM, WN,
|
||||
WNITER, TM, TN, NUM_THREADS, 1, false, 0><<<grid, thblock, 0, st>>>(X_D, K_D, Y_D, P);
|
||||
}
|
||||
|
||||
static void conv3d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
|
||||
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
|
||||
// if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1 || P.t > 1)) {
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
// int id = ggml_cuda_get_device();
|
||||
|
||||
int64_t ne = P.c * P.h * P.w * P.n;
|
||||
int64_t ne00 = P.c;
|
||||
int64_t ne01 = P.h * P.w;
|
||||
ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
|
||||
// int64_t ne = P.c * P.h * P.w * P.n;
|
||||
// int64_t ne00 = P.c;
|
||||
// int64_t ne01 = P.h * P.w;
|
||||
// ggml_cuda_pool_alloc<half> input_f16(ctx.pool(id), ne);
|
||||
|
||||
dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
|
||||
NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
|
||||
// dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
// dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1);
|
||||
// NCHW2NHWC<float, half><<<dimGrid, dimBlock, 0, st>>>(X_D, input_f16.get(), ne, ne00, ne01);
|
||||
|
||||
ne = P.c * P.r * P.s * P.k;
|
||||
ne01 = P.r * P.s;
|
||||
ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
|
||||
dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
(ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
// ne = P.c * P.r * P.s * P.k;
|
||||
// ne01 = P.r * P.s;
|
||||
// ggml_cuda_pool_alloc<half> kernel_f16(ctx.pool(id), ne);
|
||||
// dim3 dimGrid1((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM,
|
||||
// (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ;
|
||||
// NCHW2NHWC<half, half><<<dimGrid1, dimBlock, 0, st>>>(K_D, kernel_f16.get(), ne, ne00, ne01);
|
||||
|
||||
const half *X_H = input_f16.get();
|
||||
const half *K_H = kernel_f16.get();
|
||||
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
|
||||
// const half *X_H = input_f16.get();
|
||||
// const half *K_H = kernel_f16.get();
|
||||
// ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), P.k * P.Oh * P.Ow * P.n);
|
||||
|
||||
constexpr unsigned int BM_dim = 256;
|
||||
constexpr unsigned int BN_dim = 256;
|
||||
constexpr unsigned int BK_dim = 32;
|
||||
// constexpr unsigned int BM_dim = 256;
|
||||
// constexpr unsigned int BN_dim = 256;
|
||||
// constexpr unsigned int BK_dim = 32;
|
||||
|
||||
constexpr unsigned int WARPS_PER_BLOCK_M = 2;
|
||||
constexpr unsigned int WARPS_PER_BLOCK_N = 4;
|
||||
constexpr unsigned int WARPS_PER_BLOCK_K = 4;
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_M = 2;
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_N = 4;
|
||||
// constexpr unsigned int WARPS_PER_BLOCK_K = 4;
|
||||
|
||||
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
|
||||
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
|
||||
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
|
||||
const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
|
||||
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
|
||||
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
|
||||
constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
|
||||
constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
// constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
|
||||
// constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
|
||||
// constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
|
||||
// const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
|
||||
// const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
|
||||
// constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
|
||||
// constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N;
|
||||
// constexpr unsigned int NumThreads = ThreadsM * ThreadsN;
|
||||
// const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half);
|
||||
|
||||
cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
|
||||
dim3 gridDim(BlocksN, BlocksM);
|
||||
dim3 blockDim(ThreadsN, ThreadsM);
|
||||
// cudaFuncSetAttribute(conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, NumThreads>,
|
||||
// cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
|
||||
// dim3 gridDim(BlocksN, BlocksM);
|
||||
// dim3 blockDim(ThreadsN, ThreadsM);
|
||||
|
||||
conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
|
||||
WM_dim, WN_dim, WK_dim, NumThreads>
|
||||
<<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
|
||||
} else{
|
||||
// conv3d_implicit_kernel<BM_dim, BN_dim, BK_dim,
|
||||
// WM_dim, WN_dim, WK_dim, NumThreads>
|
||||
// <<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
|
||||
// const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
// to_fp32_cuda(Y_H.get(), Y_D, P.k * P.Oh * P.Ow * P.n, st);
|
||||
// } else{
|
||||
conv3d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
||||
}
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -1056,7 +1146,10 @@ void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
init_fastdiv_values(IC),
|
||||
init_fastdiv_values(KW*KH),
|
||||
init_fastdiv_values(KW),
|
||||
init_fastdiv_values(OW*OH)};
|
||||
init_fastdiv_values(OW*OH),
|
||||
init_fastdiv_values(OW*OH*OD),
|
||||
init_fastdiv_values(KW*KH*IC),
|
||||
init_fastdiv_values(KW*KH*KD)};
|
||||
|
||||
if (kernel->type == GGML_TYPE_F16) {
|
||||
conv3d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st);
|
||||
|
|
|
|||
|
|
@ -29,9 +29,13 @@ typedef struct{
|
|||
uint3 RS_fastdiv;
|
||||
uint3 S_fastdiv;
|
||||
uint3 OHOW_fastdiv;
|
||||
uint3 PQZ_fastdiv;
|
||||
uint3 RSC_fastdiv;
|
||||
uint3 TRS_fastdiv;
|
||||
} param_t;
|
||||
|
||||
|
||||
|
||||
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int NUM_THREADS>
|
||||
|
|
@ -131,14 +135,14 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
int curH = posh_ori + curR * param.d_h; // input h
|
||||
int curW = posw_ori + curS * param.d_w; // input w
|
||||
int curH = posh_ori + curR * param.dilation1; // input h
|
||||
int curW = posw_ori + curS * param.dilation0; // input w
|
||||
// apply swizzle to the dst index
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
|
|
@ -197,14 +201,14 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.stride1 - param.padding1;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.stride0 - param.padding0;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
int curH = posh_ori + curR * param.d_h; // input h
|
||||
int curW = posw_ori + curS * param.d_w; // input w
|
||||
int curH = posh_ori + curR * param.dilation1; // input h
|
||||
int curW = posw_ori + curS * param.dilation0; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
curR < param.r && curS < param.s && curC < param.c){
|
||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
|
|
@ -348,6 +352,5 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
|
|||
return address;
|
||||
}
|
||||
|
||||
|
||||
#define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256
|
||||
void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
Loading…
Reference in New Issue