From dd92b1f8d530b7227dde29adb7b406010b1eb8b2 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Sat, 14 Feb 2026 06:45:58 +0100
Subject: [PATCH] fix regressions

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp               |  8 +--
 .../vulkan-shaders/flash_attn.comp                 | 65 ++++++++++++++-----
 2 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 801da7c7f8..b3212ff139 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2806,9 +2806,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
     case FA_COOPMAT1:
         return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
     default:
-        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
-            return 128;
-        } else if (subgroup_size > 32 && Br < 4) {
+        if (subgroup_size > 32 && Br < 4) {
             return subgroup_size * 2;
         } else {
             return subgroup_size * 4;
@@ -3222,8 +3220,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
     };
 
-    const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
-
     auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
@@ -3259,7 +3255,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
         uint32_t flags = fa.first.flags; \
-        bool fa_ds = path == FA_SCALAR && disable_subgroups; \
+        bool fa_ds = fa_disable_subgroups(device, path); \
         uint32_t fa_sgs = fa_subgroup_size(device, path); \
         if (path == FAPATH) { \
             if (aligned) { \
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 563a2bcbdc..8974593e9f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -19,7 +19,7 @@
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;
 
-const uint32_t row_split = (Br < 4) ? 1 : 4;
+const uint32_t row_split = (Br < 4 || HSK <= 64) ? 1 : 4;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -205,32 +205,61 @@ void main() {
 
             barrier();
         }
 
-        [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
-            FLOAT_TYPEV4 Q_cache[rows_per_thread];
-            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
-            }
+        // More d iterations means Q register caching becomes relevant
+        // Few iterations means the additional registers needed are worse than the speed-up from caching
+        if (HSK_per_thread / 4 > 4) {
+            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Q_cache[rows_per_thread];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
+                }
+                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                        continue;
+                    }
+
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+                    }
+                }
+            }
+        } else {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
 
-                FLOAT_TYPEV4 K_Tf;
-                if (SHMEM_STAGING != 0) {
-                    K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
-                } else {
+                [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
 #if BLOCK_SIZE > 1
-                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                    uint ib = coord / BLOCK_SIZE;
-                    uint iqs = (coord % BLOCK_SIZE);
-                    K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
-                    K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
-                }
-                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+                    }
                 }
             }
         }