cuda/hip: fix loop unrolling in ssm-conv (#20369)

This commit is contained in:
uvos 2026-03-11 06:04:32 +01:00 committed by GitHub
parent 00de615345
commit 9ef7523ee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 1 deletions

View File

@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
for (int idx = 0; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
if (idx >= total_elems - tid - split_d_inner) {
break;
}
}
__syncthreads();