From 828b7e9bb191ef74a7c4052c26c42a5f55349a01 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Thu, 5 Feb 2026 12:51:59 +0100
Subject: [PATCH] use row_split when Br >= 4, change reductions to use shared memory if row_split == 1
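
The scalar flash attention shader reduced the row max, the row sum L and the
output accumulator O across the whole workgroup through shared memory, with a
barrier per reduction step. Do the reduction with subgroupShuffleXor inside
each subgroup instead, and only go through shared memory to combine the
per-subgroup partials when row_split == 1: with row_split == 4 each row group
maps onto its own subgroup (the scalar path now uses a workgroup of four
subgroups, and all non-coopmat2 pipelines are created with a required subgroup
size), so no cross-subgroup traffic is needed. Variants with Br < 4 cannot be
split four ways and now use row_split = 1; their cross-subgroup combine
assumes the four subgroups of the new default workgroup size. The shared
scratch arrays shrink from WorkGroupSize elements to a few per-subgroup slots,
and tmpshv4 becomes tmpsh_accv4 with the accumulator type ACC_TYPEV4.

On the generator side, f16acc variants are no longer skipped when fp16 is
unavailable; they fall back to float accumulators. The avg_err threshold in
ggml_vk_check_results_1 is tightened from 0.5 to 0.01.

For reference, a minimal standalone sketch of the reduction pattern
(illustrative only, not part of the patch: the buffer layout, workgroup size
and the names Buf, Out and partial are made up for the example):

    #version 450
    #extension GL_KHR_shader_subgroup_basic : enable
    #extension GL_KHR_shader_subgroup_shuffle : enable

    layout(local_size_x = 128) in;

    layout(binding = 0) readonly buffer Buf { float vals[]; };
    layout(binding = 1) writeonly buffer Out { float sums[]; };

    // One slot per subgroup; 128 is a safe upper bound for any subgroup size.
    shared float partial[128];

    void main() {
        float v = vals[gl_GlobalInvocationID.x];

        // Butterfly reduction: after log2(gl_SubgroupSize) XOR shuffles every
        // lane holds the subgroup-wide sum, with no barriers needed.
        for (uint s = 1; s < gl_SubgroupSize; s *= 2) {
            v += subgroupShuffleXor(v, s);
        }

        // Cross-subgroup step, the analogue of the row_split == 1 path:
        // one lane per subgroup publishes, then the partials are combined.
        if (gl_SubgroupInvocationID == 0) {
            partial[gl_SubgroupID] = v;
        }
        barrier();

        if (gl_LocalInvocationIndex == 0) {
            float total = 0.0;
            for (uint i = 0; i < gl_NumSubgroups; ++i) {
                total += partial[i];
            }
            sums[gl_WorkGroupID.x] = total;
        }
    }
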
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp               | 12 +--
 .../ggml-vulkan/vulkan-shaders/flash_attn.comp     | 72 +++++++++++--------
 .../vulkan-shaders/vulkan-shaders-gen.cpp          |  8 +-
 3 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 114992da08..0ca00aa4b5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -3213,7 +3213,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
                 break;
             default:
-                wg_size = scalar_flash_attention_workgroup_size;
+                wg_size = device->subgroup_size * 4;
                 break;
         }
 
@@ -3243,15 +3243,15 @@
     if (path == FAPATH) { \
         if (aligned) { \
             if (f32acc) { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
            } else { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
            } \
        } else { \
            if (f32acc) { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
            } else { \
-                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
            } \
        } \
    } \
@@ -16068,7 +16068,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
         ggml_vk_print_graph_origin(tensor, done);
     }
 
-    if (avg_err > 0.5 || std::isnan(avg_err)) {
+    if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
         std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
         if (src0 != nullptr) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index af0b371cfa..49d50ed854 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -19,10 +19,11 @@
 
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;
 
-const uint32_t row_split = 4;
+const uint32_t row_split = (Br < 4) ? 1 : 4;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
+const uint32_t num_subgroups = WorkGroupSize / SubGroupSize;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -42,8 +43,10 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
     return elem;
 }
 
-shared float tmpsh[WorkGroupSize];
-shared vec4 tmpshv4[WorkGroupSize];
+const uint32_t tmpsh_reduction_size = row_split == 1 ? num_subgroups * D_split : 0;
+const uint32_t tmpsh_size = tmpsh_reduction_size > 4 ? tmpsh_reduction_size : 4;
+shared float tmpsh[tmpsh_size];
+shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
 
 shared FLOAT_TYPE masksh[Bc][Br];
 
@@ -279,7 +282,7 @@ void main() {
                 FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
                 [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Of[r][d] += ACC_TYPE(Pf[r][c] * Vf);
+                    Of[r][d] += ACC_TYPEV4(Pf[r][c] * Vf);
                 }
             }
         }
@@ -293,57 +296,68 @@ void main() {
 
         // reduce across threads
 
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            float rowmaxf, eMf;
+            float rowmaxf = Mf[r];
 
-            tmpsh[tid] = Mf[r];
             // Compute max across the row
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
+            }
+            if (row_split == 1) {
+                // Reduce inside workgroup with shmem
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
                 }
                 barrier();
+                rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid],
+                                      tmpsh[1 * D_split + d_tid]),
+                                  tmpsh[2 * D_split + d_tid]),
+                              tmpsh[3 * D_split + d_tid]);
             }
-            rowmaxf = tmpsh[d_tid];
-            barrier();
 
             float Moldf = Mf[r];
 
             // M = max(rowmax, Mold)
             // eM = e^(Mold - M)
             Mf[r] = max(rowmaxf, Moldf);
-            eMf = exp(Moldf - Mf[r]);
+            float eMf = exp(Moldf - Mf[r]);
             Lf[r] = eMf*Lf[r];
 
-            tmpsh[tid] = Lf[r];
-
             // Compute sum across the row
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                Lf[r] += subgroupShuffleXor(Lf[r], s);
+            }
+            if (row_split == 1) {
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
                 }
                 barrier();
+                Lf[r] = tmpsh[0 * D_split + d_tid] +
+                        tmpsh[1 * D_split + d_tid] +
+                        tmpsh[2 * D_split + d_tid] +
+                        tmpsh[3 * D_split + d_tid];
             }
-            Lf[r] = tmpsh[d_tid];
-            barrier();
 
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                 Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
-                tmpshv4[tid] = Of[r][d];
-                barrier();
-                [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                    if (tid < s) {
-                        Of[r][d] += ACC_TYPEV4(tmpshv4[tid + s]);
-                        tmpshv4[tid] = Of[r][d];
+                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                    Of[r][d] += subgroupShuffleXor(Of[r][d], s);
+                }
+                if (row_split == 1) {
+                    barrier();
+                    if (gl_SubgroupInvocationID == d_tid) {
+                        tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
                     }
                     barrier();
+                    Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] +
+                               tmpsh_accv4[1 * D_split + d_tid] +
+                               tmpsh_accv4[2 * D_split + d_tid] +
+                               tmpsh_accv4[3 * D_split + d_tid];
                 }
-                Of[r][d] = ACC_TYPEV4(tmpshv4[d_tid]);
-                barrier();
            }
        }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index b244f9fa5f..7fbe45d33f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -630,12 +630,10 @@ void process_shaders() {
 
     // flash attention
     for (const bool& f16acc : {false, true}) {
-        if (!fp16 && f16acc) continue;
-
         std::map<std::string, std::string> fa_base_dict = base_dict;
-        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
-        if (f16acc) {
+        fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
+        if (fp16 && f16acc) {
             fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
         }