Merge branch 'refactor-cuda-core-path' into conv2d-implicit
This commit is contained in:
commit
28b7094750
|
|
@ -27,6 +27,7 @@ static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restri
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename src_T, typename dst_T>
|
||||
static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){
|
||||
|
||||
|
|
@ -63,6 +64,8 @@ static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, co
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
|
||||
// layout: 0, NHWC; 1, NCHW
|
||||
|
|
@ -82,6 +85,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
const uint by = blockIdx.y;
|
||||
|
||||
const uint PQ = param.Oh * param.Ow;
|
||||
const uint CHW = param.c * param.h * param.w;
|
||||
|
||||
// Warp tile
|
||||
const uint lane_id = tx % WARPSIZE;
|
||||
|
|
@ -119,107 +123,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
constexpr uint rowStrideA = (NUM_THREADS * 4) / BK;
|
||||
|
||||
// ldg
|
||||
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||
if(vec_load){
|
||||
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < end_k){
|
||||
if constexpr (std::is_same_v<T, float>){
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
||||
smemweight[weight_sts_addr + offset + 0] = tmp.x;
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
||||
}else{ // read 4 halves
|
||||
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4])[0];
|
||||
const half *val = reinterpret_cast<const half *>(&tmp);
|
||||
smemweight[weight_sts_addr + offset + 0] = val[0];
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
}else{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 + i < end_k){
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + start_k + innerColA * 4 + i];
|
||||
} else {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
loadFilter<T, BN, rowStrideA, layout, vec_load, ksplit, PAD>
|
||||
(kernel, smemweight, by, innerRowA, innerColA, weightKOffset,
|
||||
start_k, end_k, param);
|
||||
|
||||
loadInput<BM, rowStrideA, layout, vec_load, ksplit, PAD>
|
||||
(input, smeminput, bx, innerRowA, innerColA,
|
||||
start_k, end_k, PQ, CHW, inChannelOffset, param);
|
||||
|
||||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||
int inOffset = n * param.c * param.h * param.w ;
|
||||
if(vec_load){
|
||||
const uint cur0 = fastdiv(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
const uint cur0 = fastdiv(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(start_k + innerColA * 4 + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// lds
|
||||
|
|
@ -279,106 +190,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
}
|
||||
}
|
||||
// ldg
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||
if(vec_load){
|
||||
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK < end_k){
|
||||
if constexpr (std::is_same_v<T, float>){
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = tmp.x;
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = tmp.y;
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
||||
} else {
|
||||
float2 tmp = reinterpret_cast<const float2 *>(&kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK])[0];
|
||||
const half *val = reinterpret_cast<const half *>(&tmp);
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 0] = val[0];
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + (BN+PAD)] = val[1];
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}else{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
if (by * BN + innerRowA + offset < param.k && innerColA * 4 + crs + BK + i < end_k){
|
||||
// float4 tmp = reinterpret_cast<float4 *>(¶m.weight[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i])[0];
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = kernel[(by * BN + innerRowA + offset) * weightKOffset + innerColA * 4 + crs + BK + i];
|
||||
} else {
|
||||
smemweight[write_flag * (BN+PAD) * BK + weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: bx * BM + innerRowA + offset, param.OW_fastdiv) * param.v - param.q;
|
||||
int inOffset = n * param.c * param.h * param.w ;
|
||||
if(vec_load){
|
||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
|
||||
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i){
|
||||
const uint cur0 = fastdiv(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const uint cur1 = fastdiv(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint cur2 = fastmodulo(fastmodulo(innerColA * 4 + crs + BK + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const uint curC = layout == 0 ? cur2 : cur0;
|
||||
const uint curR = layout == 0 ? cur0 : cur1;
|
||||
const uint curS = layout == 0 ? cur1 : cur2;
|
||||
loadFilter<T, BN, rowStrideA, layout, vec_load, ksplit, PAD>
|
||||
(kernel, &smemweight[write_flag * (BN+PAD) * BK], by, innerRowA, innerColA, weightKOffset,
|
||||
crs+BK, end_k, param);
|
||||
|
||||
loadInput<BM, rowStrideA, layout, vec_load, ksplit, PAD>
|
||||
(input, &smeminput[write_flag * (BM+PAD) * BK], bx, innerRowA, innerColA,
|
||||
crs + BK, end_k, PQ, CHW, inChannelOffset, param);
|
||||
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
write_flag ^= 1;
|
||||
|
|
|
|||
|
|
@ -359,6 +359,139 @@ __device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
|
|||
return address;
|
||||
}
|
||||
|
||||
template<typename T, const int BN, const int rowStrideA, const int layout,
|
||||
const bool vec_load, const int ksplit, const int PAD>
|
||||
__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel,
|
||||
T * __restrict__ smemweight,
|
||||
const unsigned int by,
|
||||
const unsigned int innerRowA,
|
||||
const unsigned int innerColA,
|
||||
const unsigned int weightKOffset,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
const param_t param){
|
||||
|
||||
const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
||||
const unsigned int kidx = start_k + innerColA * 4;
|
||||
#pragma unroll
|
||||
for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||
const unsigned int nidx = by * BN + innerRowA + offset;
|
||||
if (vec_load) {
|
||||
if (nidx < param.k && kidx < end_k) {
|
||||
if constexpr (std::is_same_v<T, float>){
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&kernel[nidx * weightKOffset + kidx])[0];
|
||||
smemweight[weight_sts_addr + offset + 0] = tmp.x;
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
||||
} else { // read 4 halves
|
||||
float2 tmp = reinterpret_cast<const float2 *>(&kernel[nidx * weightKOffset + kidx])[0];
|
||||
const half *val = reinterpret_cast<const half *>(&tmp);
|
||||
smemweight[weight_sts_addr + offset + 0] = val[0];
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (nidx < param.k && kidx + i < end_k) {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i];
|
||||
} else {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<const int BM, const int rowStrideA, const int layout,
|
||||
const bool vec_load, const int ksplit, const int PAD>
|
||||
__device__ __forceinline__ void loadInput(const float * __restrict__ input,
|
||||
float * __restrict__ smeminput,
|
||||
const unsigned int bx,
|
||||
const unsigned int innerRowA,
|
||||
const unsigned int innerColA,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
const unsigned int PQ,
|
||||
const unsigned int CHW,
|
||||
const unsigned int inChannelOffset,
|
||||
const param_t param) {
|
||||
const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
const unsigned int kidx = start_k + innerColA * 4;
|
||||
#pragma unroll
|
||||
for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
const unsigned int midx = bx * BM + innerRowA + offset;
|
||||
int n = (ksplit > 0) ? midx / PQ : blockIdx.z;
|
||||
const unsigned int npq_res = midx % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.v - param.q;
|
||||
const unsigned int inOffset = n * CHW;
|
||||
if (vec_load) {
|
||||
const unsigned int cur0 = fastdiv(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const unsigned int cur1 = fastdiv(fastmodulo(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int cur2 = fastmodulo(fastmodulo(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int curC = layout == 0 ? cur2 : cur0;
|
||||
const unsigned int curR = layout == 0 ? cur0 : cur1;
|
||||
const unsigned int curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) {
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const unsigned int cur0 = fastdiv(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const unsigned int cur1 = fastdiv(fastmodulo(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int curC = layout == 0 ? cur2 : cur0;
|
||||
const unsigned int curR = layout == 0 ? cur0 : cur1;
|
||||
const unsigned int curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) {
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
|
||||
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
|
|||
Loading…
Reference in New Issue