change padding size to 1; added padding to input smem
This commit is contained in:
parent
3f99818925
commit
ac77b8d0e0
|
|
@ -25,15 +25,15 @@ static __global__ void reduce_f32(const float * __restrict__ x, float * __restri
|
|||
template<typename T, const int BM, const int BN, const int BK, const int WM, const int WN,
|
||||
const int WNITER, const int TM, const int TN, const int NUM_THREADS,
|
||||
// layout: 0, NHWC; 1, NCHW
|
||||
const int layout, const bool vec_load, const int ksplit, const int PAD=4>
|
||||
const int layout, const bool vec_load, const int ksplit, const int PAD=1>
|
||||
static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||
const T * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
const param_t param) {
|
||||
|
||||
// __shared__ char smem[4 * (TM*TN*NUM_THREADS <= (BM * BK + BK * (BN+PAD)) ? (BM * BK + BK * (BN+PAD)) : (TM*TN*NUM_THREADS))];
|
||||
__shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * BM * BK + sizeof(T)*2*BK * (BN+PAD) ?
|
||||
sizeof(float)*2*BM*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)];
|
||||
__shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ?
|
||||
sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)];
|
||||
// __shared__ float smeminput[2 * BM * BK];
|
||||
// __shared__ float smemweight[2 * BK * (BN+PAD)];
|
||||
T *smemweight = reinterpret_cast<T *>(smem);
|
||||
|
|
@ -175,7 +175,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
// int curS = ((tx % 2) * 4 % (param.s * param.c)) / param.c; // kernel r offset
|
||||
// int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset
|
||||
|
||||
const uint input_sts_addr = innerRowA + innerColA * BM * 4;
|
||||
const uint input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
#pragma unroll
|
||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||
|
|
@ -206,14 +206,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM] = tmp.y;
|
||||
smeminput[input_sts_addr + offset + 2*BM] = tmp.z;
|
||||
smeminput[input_sts_addr + offset + 3*BM] = tmp.w;
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[input_sts_addr + offset + i*BM] = 0.f;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
|
|
@ -239,9 +239,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp];
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[input_sts_addr + offset + i*BM] = 0.f;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -390,8 +390,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < TM; ++i)
|
||||
input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * BM * BK +
|
||||
(subcrs + 1) * BM + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i];
|
||||
input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK +
|
||||
(subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i];
|
||||
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < 8; ++i)
|
||||
|
|
@ -497,14 +497,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + BM] = tmp.y;
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z;
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
|
|
@ -531,9 +531,9 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = input[inOffset + inOffsetTmp];
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
|
||||
smeminput[write_flag * (BM+PAD) * BK + input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -553,7 +553,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < TM; ++i)
|
||||
input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * BM * BK +
|
||||
input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * (BM+PAD) * BK +
|
||||
input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i];
|
||||
#pragma unroll
|
||||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
||||
|
|
|
|||
Loading…
Reference in New Issue