cuda/hip: fix loop unrolling in ssm-conv (#20369)
This commit is contained in:
parent
00de615345
commit
9ef7523ee9
|
|
@@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
|||
int row = tid / load_cols;
|
||||
int col = tid % load_cols;
|
||||
#pragma unroll
|
||||
-for (int idx = tid; idx < total_elems; idx += split_d_inner) {
|
||||
+for (int idx = 0; idx < total_elems; idx += split_d_inner) {
|
||||
if (row < (int)split_d_inner) {
|
||||
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
||||
}
|
||||
|
|
@@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
|||
col += split_d_inner;
|
||||
row += col / load_cols;
|
||||
col = col % load_cols;
|
||||
+if (idx >= total_elems - tid - split_d_inner) {
|
||||
+break;
|
||||
+}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue