From 3c2088121cb7d7944a8b6060d121d1a7e4edb7ad Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 7 Feb 2026 07:50:56 +0100 Subject: [PATCH] add padding to mask shmem buffer --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec4a831fd6..24589dfe7c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -47,7 +47,8 @@ const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; -shared FLOAT_TYPE masksh[Bc][Br]; +const uint32_t masksh_stride = Br + 1; +shared FLOAT_TYPE masksh[Bc * masksh_stride]; const uint qfstride = HSK / 4 + 1; shared FLOAT_TYPEV4 Qf[Br * qfstride]; @@ -153,10 +154,10 @@ void main() { if (idx + tid < Bc * Br) { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; + masksh[c * masksh_stride + r] = m; max_mask = max(max_mask, float(m)); } else { - masksh[c][r] = FLOAT_TYPE(0); + masksh[c * masksh_stride + r] = FLOAT_TYPE(0); } } } @@ -222,7 +223,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)]; + FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)]; Sf[r][c] += slope[r]*mvf; }