cuda/hip: fix loop unrolling in ssm-conv (#20369)
This commit is contained in:
parent
00de615345
commit
9ef7523ee9
|
|
@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||||
int row = tid / load_cols;
|
int row = tid / load_cols;
|
||||||
int col = tid % load_cols;
|
int col = tid % load_cols;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
|
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
|
||||||
if (row < (int)split_d_inner) {
|
if (row < (int)split_d_inner) {
|
||||||
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
smem[row * n_cols + col] = x_block[row * stride_x + col];
|
||||||
}
|
}
|
||||||
|
|
@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
|
||||||
col += split_d_inner;
|
col += split_d_inner;
|
||||||
row += col / load_cols;
|
row += col / load_cols;
|
||||||
col = col % load_cols;
|
col = col % load_cols;
|
||||||
|
if (idx >= total_elems - tid - split_d_inner) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue