From 8236c453a503511e3fff6110ad6f6e0d77354547 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Sun, 8 Feb 2026 10:27:10 +0100
Subject: [PATCH] stage V loads through shmem

---
 .../vulkan-shaders/flash_attn.comp           | 40 ++++++++++++++++---
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index e4ca125eb4..6d85212d44 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -189,6 +189,7 @@ void main() {
         }
 
         if (K_LOAD_SHMEM != 0) {
+            barrier();
             [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t d = (idx + tid) % (HSK / 4);
                 uint32_t c = (idx + tid) / (HSK / 4);
@@ -293,6 +294,30 @@
             }
         }
 
+        if (K_LOAD_SHMEM != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV / 4);
+                uint32_t c = (idx + tid) / (HSV / 4);
+                if (c < Bc) {
+                    FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
+#else
+                        V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
+            }
+            barrier();
+        }
+
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                 continue;
             }
@@ -305,14 +330,19 @@
             }
 
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Vf;
+                if (K_LOAD_SHMEM != 0) {
+                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
+                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
 #else
-                FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                    Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
+                }
                 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Of[r][d] += ACC_TYPEV4(Pf[r] * Vf);
                 }