Fix a bug in calculating the filter row index

This commit is contained in:
bssrdf 2025-11-09 17:30:08 -05:00
parent 36c0df7904
commit d2d814c156
1 changed file with 2 additions and 2 deletions

View File

@ -104,7 +104,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
// TODO: move some checks outside of loop?
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -302,7 +302,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
// TODO : move some checks outside of the loop
if (thread_row < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);