diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1872385b65..d8a2dc098e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2759,11 +2759,11 @@ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { if (hsv >= 192) { - return 2; - } else if ((hsv | hsk) & 8 || small_cache) { - return 4; - } else { return 8; + } else if ((hsv | hsk) & 8 || small_cache) { + return 8; + } else { + return 16; } } @@ -2789,13 +2789,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if (small_rows) { return {scalar_flash_attention_num_small_rows, 64}; } else { - if ((hsv | hsk) & 8) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; - } + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; } } @@ -3213,7 +3207,11 @@ static void ggml_vk_load_shaders(vk_device& device) { wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc break; default: - wg_size = device->subgroup_size * 4; + if (device->subgroup_size > 32 && rows_cols[0] < 4) { + wg_size = device->subgroup_size * 2; + } else { + wg_size = device->subgroup_size * 4; + } break; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 49d50ed854..223b58d8ef 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -309,10 +309,10 @@ void main() { tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; } barrier(); - rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid], - tmpsh[1 * D_split + d_tid]), - tmpsh[2 * D_split + d_tid]), - tmpsh[3 * D_split + d_tid]); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } } float Moldf = Mf[r]; @@ -334,10 +334,10 @@ void main() { tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; } barrier(); - Lf[r] = tmpsh[0 * D_split + d_tid] + - tmpsh[1 * D_split + d_tid] + - tmpsh[2 * D_split + d_tid] + - tmpsh[3 * D_split + d_tid]; + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { @@ -352,10 +352,10 @@ void main() { tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; } barrier(); - Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] + - tmpsh_accv4[1 * D_split + d_tid] + - tmpsh_accv4[2 * D_split + d_tid] + - tmpsh_accv4[3 * D_split + d_tid]; + Of[r][d] = tmpsh_accv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpsh_accv4[s * D_split + d_tid]; + } } } }