diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang index 30ab494bc0..3115bd5fb5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.slang @@ -262,6 +262,23 @@ void gqaStore(const in uint32_t r, const in uint3 data_ov4[o_offset + offset] = vector(elems); } +void stageKVToShmem(L loader, uint j, uint HS, uint tid, uint KV) { + GroupMemoryBarrierWithGroupSync(); + [unroll] for (uint idx = 0; idx < Bc * HS / 4; idx += WorkGroupSize) { + uint d = (idx + tid) % (HS / 4); + uint c = (idx + tid) / (HS / 4); + if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) { + vector K_Tf = vector(0); + if (!KV_bounds_check || j * Bc + c < KV) { + K_Tf = loader.load(j * Bc + c, d); + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } + } + GroupMemoryBarrierWithGroupSync(); +} + [shader("compute")] [numthreads(WorkGroupSize, 1, 1)] void main( @@ -419,20 +436,7 @@ void main( } if (SHMEM_STAGING != 0) { - GroupMemoryBarrierWithGroupSync(); - [unroll] for (uint idx = 0; idx < Bc * HSK / 4; idx += WorkGroupSize) { - uint d = (idx + tid) % (HSK / 4); - uint c = (idx + tid) / (HSK / 4); - if (idx + WorkGroupSize <= Bc * HSK / 4 || c < Bc) { - vector K_Tf = vector(0); - if (!KV_bounds_check || j * Bc + c < idcs.KV) { - K_Tf = kloader.load(j * Bc + c, d); - } - - kvsh[c * kvsh_stride + d] = K_Tf; - } - } - GroupMemoryBarrierWithGroupSync(); + stageKVToShmem(kloader, j, HSK, tid, idcs.KV); } // More d iterations means Q register caching becomes relevant @@ -533,20 +537,7 @@ void main( } if (SHMEM_STAGING != 0) { - GroupMemoryBarrierWithGroupSync(); - [unroll] for (uint idx = 0; idx < Bc * HSV / 4; idx += WorkGroupSize) { - uint d = (idx + tid) % (HSV / 4); - uint c = (idx + tid) / (HSV / 4); - if (idx + WorkGroupSize <= Bc * HSV / 4 || c < Bc) { - vector V_Tf = vector(0); - if (!KV_bounds_check || j * Bc + c < idcs.KV) { - V_Tf = vloader.load(j * Bc + c, d); - } - - kvsh[c * kvsh_stride + d] = V_Tf; - } - } - GroupMemoryBarrierWithGroupSync(); + stageKVToShmem(vloader, j, HSV, tid, idcs.KV); } [unroll] for (uint c = 0; c < cols_per_thread; ++c) {