Add missing batch index bounds check (n < param.n) to the load guards in tileMemcpySwizzleA and tileMemcpyLoadA
This commit is contained in:
parent
a660d4d45d
commit
fac6f0adc3
|
|
@@ -146,7 +146,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
||||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||||
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
|
||||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
||||||
} else{
|
} else{
|
||||||
|
|
@@ -214,7 +214,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
||||||
int curH = posh_ori + curR * param.d_h; // input h
|
int curH = posh_ori + curR * param.d_h; // input h
|
||||||
int curW = posw_ori + curS * param.d_w; // input w
|
int curW = posw_ori + curS * param.d_w; // input w
|
||||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||||
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
|
||||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
||||||
} else{
|
} else{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue