add padding to mask shmem buffer
This commit is contained in:
parent
07afb5128f
commit
3c2088121c
|
|
@ -47,7 +47,8 @@ const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
|
|||
shared float tmpsh[tmpsh_size];
|
||||
shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
|
||||
|
||||
shared FLOAT_TYPE masksh[Bc][Br];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
const uint qfstride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qfstride];
|
||||
|
|
@ -153,10 +154,10 @@ void main() {
|
|||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
masksh[c * masksh_stride + r] = m;
|
||||
max_mask = max(max_mask, float(m));
|
||||
} else {
|
||||
masksh[c][r] = FLOAT_TYPE(0);
|
||||
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -222,7 +223,7 @@ void main() {
|
|||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)];
|
||||
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
|
||||
|
||||
Sf[r][c] += slope[r]*mvf;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue