diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d78c727e53..6cf15b43bb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { if (hsv >= 192) { return 2; + } else if ((hsv | hsk) & 8) { + return 4; } else { return 8; } @@ -2535,9 +2537,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; } else { - return {get_fa_scalar_num_large_rows(hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; } } } @@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ce8c068d7a..fd48d25475 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7859,6 +7859,9 @@ static std::vector> make_test_cases_perf() { } } + // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 + test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) {