fixed a bug of not bound checking batch dimension
This commit is contained in:
parent
5e1352cb60
commit
d6d24487c2
|
|
@ -180,7 +180,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d && curC < param.c && kidx < end_k){
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d &&
|
||||
n < param.n && curC < param.c && kidx < end_k){
|
||||
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
|
||||
} else{
|
||||
|
|
@ -249,7 +250,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
|
||||
const int curC = curIdx.x;
|
||||
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
|
||||
&& curC < param.c && kidx < end_k){
|
||||
&& n < param.n && curC < param.c && kidx < end_k){
|
||||
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
|
||||
} else{
|
||||
|
|
|
|||
Loading…
Reference in New Issue