add padding to mask shmem buffer

This commit is contained in:
Ruben Ortlam 2026-02-07 07:50:56 +01:00
parent 07afb5128f
commit 3c2088121c
1 changed files with 5 additions and 4 deletions

View File

@ -47,7 +47,8 @@ const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
shared float tmpsh[tmpsh_size];
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
shared FLOAT_TYPE masksh[Bc][Br];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
const uint qfstride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qfstride];
@ -153,10 +154,10 @@ void main() {
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
masksh[c * masksh_stride + r] = m;
max_mask = max(max_mask, float(m));
} else {
masksh[c][r] = FLOAT_TYPE(0);
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
}
}
}
@ -222,7 +223,7 @@ void main() {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)];
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
Sf[r][c] += slope[r]*mvf;
}