From e880cb2e0d66da251a9c2373d26e4093c1102300 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 13 Mar 2026 08:18:33 +0100 Subject: [PATCH] Revert "move kv shmem staging to function" This reverts commit 0349025db8249f5468cc273507d901c7ff396a3f. --- .../vulkan-shaders/flash_attn.slang | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang index 3115bd5fb5..30ab494bc0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang @@ -262,23 +262,6 @@ void gqaStore(const in uint32_t r, const in uint3 data_ov4[o_offset + offset] = vector(elems); } -void stageKVToShmem(L loader, uint j, uint HS, uint tid, uint KV) { - GroupMemoryBarrierWithGroupSync(); - [unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) { - uint d = (idx + tid) % (HS / 4); - uint c = (idx + tid) / (HS / 4); - if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) { - vector K_Tf = vector(0); - if (!KV_bounds_check || j * Bc + c < KV) { - K_Tf = loader.load(j * Bc + c, d); - } - - kvsh[c * kvsh_stride + d] = K_Tf; - } - } - GroupMemoryBarrierWithGroupSync(); -} - [shader("compute")] [numthreads(WorkGroupSize, 1, 1)] void main( @@ -436,7 +419,20 @@ void main( } if (SHMEM_STAGING != 0) { - stageKVToShmem(kloader, j, HSK, tid, idcs.KV); + GroupMemoryBarrierWithGroupSync(); + [unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) { + uint d = (idx + tid) % (HSK / 4); + uint c = (idx + tid) / (HSK / 4); + if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) { + vector K_Tf = vector(0); + if (!KV_bounds_check || j * Bc + c < idcs.KV) { + K_Tf = kloader.load(j * Bc + c, d); + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } + } + GroupMemoryBarrierWithGroupSync(); } // More d iterations means Q register caching becomes relevant @@ -537,7 +533,20 @@ void main( } if (SHMEM_STAGING != 0) { - stageKVToShmem(vloader, j, HSV, tid, idcs.KV); + GroupMemoryBarrierWithGroupSync(); + [unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) { + uint d = (idx + tid) % (HSV / 4); + uint c = (idx + tid) / (HSV / 4); + if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) { + vector V_Tf = vector(0); + if (!KV_bounds_check || j * Bc + c < idcs.KV) { + V_Tf = vloader.load(j * Bc + c, d); + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + GroupMemoryBarrierWithGroupSync(); } [unroll] for (uint c = 0; c < cols_per_thread; ++c) {