From 50a420e0444dfff5c604e319038f591288e3e35e Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 8 Feb 2026 08:32:00 +0100 Subject: [PATCH] fuse lf accumulation, pf and v accumulation into a loop --- .../vulkan-shaders/flash_attn.comp | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e6a1de3f70..e641debe3c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -236,7 +236,6 @@ void main() { } } - FLOAT_TYPE Pf[rows_per_thread][cols_per_thread]; float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { float rowmaxf = NEG_FLT_MAX_OVER_2; @@ -252,21 +251,8 @@ void main() { // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); - } eMf[r] = exp(Moldf - Mf[r]); - - // Compute sum across row of P - float rowsumf = 0.0; - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - rowsumf += Pf[r][c]; - } - - Lf[r] = eMf[r]*Lf[r] + rowsumf; + Lf[r] = eMf[r]*Lf[r]; } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { @@ -279,6 +265,13 @@ void main() { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } + + FLOAT_TYPE Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); @@ -289,7 +282,7 @@ void main() { FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf); + Of[r][d] += ACC_TYPEV4(Pf[r] * Vf); } } }