From ca5ec63cfb3fcd2e88e771b19ed96cbc53db3ea0 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 7 Feb 2026 17:15:55 +0100 Subject: [PATCH] cache q values into registers for KQ --- .../ggml-vulkan/vulkan-shaders/flash_attn.comp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 24589dfe7c..e6a1de3f70 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -184,11 +184,17 @@ void main() { } } - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid]; } - [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -198,7 +204,7 @@ void main() { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); } } }