fixed missing dilation

This commit is contained in:
bssrdf 2025-10-14 11:12:55 -04:00
parent 2237722056
commit 3e2f722d11
1 changed files with 8 additions and 8 deletions

View File

@ -195,8 +195,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
@ -229,8 +229,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const uint curC = layout == 0 ? cur2 : cur0;
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && start_k + innerColA * 4 + i < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
@ -482,8 +482,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?
@ -517,8 +517,8 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
const uint curR = layout == 0 ? cur0 : cur1;
const uint curS = layout == 0 ? cur1 : cur2;
const int curH = posh_ori + curR; // input h
const int curW = posw_ori + curS; // input w
const int curH = posh_ori + curR * param.d_h; // input h
const int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && innerColA * 4 + crs + BK + i < end_k){
// int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
int inOffsetTmp = layout == 0 ?