From 4819fd3014a1cdb943fd6ff10b47c35a91cae020 Mon Sep 17 00:00:00 2001
From: Ruben Ortlam
Date: Sun, 8 Feb 2026 15:41:32 +0100
Subject: [PATCH] dynamic subgroups for Intel

---
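On Intel (VK_VENDOR_ID_INTEL) the host code now passes 0xFFFFFFFF
(SUBGROUPS_DISABLED) as the SubGroupSize specialization constant and selects a
fixed 128-invocation workgroup for the coopmat1 and scalar paths. The shader
then skips the subgroupShuffleXor() reductions and instead reduces each row
group through shared memory with a tree reduction that stops at D_split, so
each d_tid lane keeps its own partial result, as the shuffle path does.

For reference, a minimal standalone compute shader showing the same
shared-memory tree-reduction pattern (illustrative sketch only, not part of
the diff below; the buffer layout and workgroup size here are arbitrary):

    #version 450
    #extension GL_EXT_control_flow_attributes : enable

    layout(local_size_x = 128) in;

    layout(std430, binding = 0) buffer Data { float vals[]; };

    // one shared slot per invocation, like tmpsh when SubGroupSize ==
    // SUBGROUPS_DISABLED (tmpsh_size == WorkGroupSize)
    shared float tmpsh[gl_WorkGroupSize.x];

    void main() {
        const uint tid = gl_LocalInvocationIndex;

        tmpsh[tid] = vals[gl_GlobalInvocationID.x];
        barrier();

        // pair each active thread with its partner via tid ^ s and halve the
        // active range every step; after the loop, slot 0 holds the maximum
        // of the workgroup's values
        [[unroll]] for (uint s = gl_WorkGroupSize.x / 2; s > 0; s >>= 1) {
            if (tid < s) {
                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
            }
            barrier();
        }

        // write the per-workgroup maximum back (overwrites one input element)
        if (tid == 0) {
            vals[gl_WorkGroupID.x] = tmpsh[0];
        }
    }

The patch's version runs the same loop but starts at threads_per_rowgroup / 2,
stops at D_split instead of 1, and reads the result per row group.
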
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          |  23 ++--
 .../vulkan-shaders/flash_attn.comp            | 112 ++++++++++++------
 .../vulkan-shaders/flash_attn_base.glsl       |   2 +
 3 files changed, 95 insertions(+), 42 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 16746e75d9..5b82d60b64 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -3191,6 +3191,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
     };
 
+    const bool disable_subgroups = device->vendor_id == VK_VENDOR_ID_INTEL;
+
     auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
@@ -3206,10 +3208,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
             wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
             break;
         case FA_COOPMAT1:
-            wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+            if (disable_subgroups) {
+                wg_size = 128;
+            } else {
+                wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+            }
             break;
         default:
-            if (device->subgroup_size > 32 && rows_cols[0] < 4) {
+            if (disable_subgroups) {
+                wg_size = 128;
+            } else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
                 wg_size = device->subgroup_size * 2;
             } else {
                 wg_size = device->subgroup_size * 4;
@@ -3227,7 +3235,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
 
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, shmem_staging, flags};
+        const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size;
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3243,15 +3252,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (path == FAPATH) { \
             if (aligned) { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
                 } \
             } else { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!disable_subgroups && (FAPATH!=FA_COOPMAT2)), ((!disable_subgroups && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
                 } \
             } \
         } \
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 7324d770a0..0ed6a390c5 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -43,7 +43,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in ACC_
     return elem;
 }
 
-const uint32_t tmpsh_size = row_split == 1 ? num_subgroups * D_split : 1;
+// If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions
+const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize;
 shared float tmpsh[tmpsh_size];
 shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size];
 
@@ -67,6 +68,7 @@ void main() {
     const uint32_t tid = gl_LocalInvocationIndex;
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
+    const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
     const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
 
@@ -359,20 +361,33 @@ void main() {
         float rowmaxf = Mf[r];
 
         // Compute max across the row
-        [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-            rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
-        }
-        if (row_split == 1) {
-            // Reduce inside workgroup with shmem
-            barrier();
-            if (gl_SubgroupInvocationID == d_tid) {
-                tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
+        if (SubGroupSize != SUBGROUPS_DISABLED) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
             }
-            barrier();
-            rowmaxf = tmpsh[d_tid];
-            [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
-                rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
+            if (row_split == 1) {
+                // Reduce inside workgroup with shmem
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
+                }
+                barrier();
+                rowmaxf = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
+                }
             }
+        } else {
+            barrier();
+            tmpsh[tid] = rowmaxf;
+            barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
+                }
+                barrier();
+            }
+            rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
 
         float Moldf = Mf[r];
@@ -385,37 +400,64 @@ void main() {
         Lf[r] = eMf*Lf[r];
 
         // Compute sum across the row
-        [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-            Lf[r] += subgroupShuffleXor(Lf[r], s);
-        }
-        if (row_split == 1) {
-            barrier();
-            if (gl_SubgroupInvocationID == d_tid) {
-                tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
+        if (SubGroupSize != SUBGROUPS_DISABLED) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                Lf[r] += subgroupShuffleXor(Lf[r], s);
             }
-            barrier();
-            Lf[r] = tmpsh[d_tid];
-            [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
-                Lf[r] += tmpsh[s * D_split + d_tid];
+            if (row_split == 1) {
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
+                }
+                barrier();
+                Lf[r] = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    Lf[r] += tmpsh[s * D_split + d_tid];
+                }
             }
+        } else {
+            barrier();
+            tmpsh[tid] = Lf[r];
+            barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
+                }
+                barrier();
+            }
+            Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
 
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
             Of[r][d] = ACC_TYPE(eMf) * Of[r][d];
 
-            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
-                Of[r][d] += subgroupShuffleXor(Of[r][d], s);
-            }
-            if (row_split == 1) {
-                barrier();
-                if (gl_SubgroupInvocationID == d_tid) {
-                    tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
+            if (SubGroupSize != SUBGROUPS_DISABLED) {
+                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                    Of[r][d] += subgroupShuffleXor(Of[r][d], s);
                 }
-                barrier();
-                Of[r][d] = tmpsh_accv4[d_tid];
-                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
-                    Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
+                if (row_split == 1) {
+                    barrier();
+                    if (gl_SubgroupInvocationID == d_tid) {
+                        tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
+                    }
+                    barrier();
+                    Of[r][d] = tmpsh_accv4[d_tid];
+                    [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                        Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
+                    }
                 }
+            } else {
+                barrier();
+                tmpsh_accv4[tid] = Of[r][d];
+                barrier();
+                [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                    if (rowgroup_tid < s) {
+                        Of[r][d] += tmpsh_accv4[tid ^ s];
+                        tmpsh_accv4[tid] = Of[r][d];
+                    }
+                    barrier();
+                }
+                Of[r][d] = tmpsh_accv4[row_tid * threads_per_rowgroup + d_tid];
             }
         }
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index 0a077f6876..857e9faff4 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -66,6 +66,8 @@ layout (push_constant) uniform parameter {
 #define SINK_ENABLE_BIT (1<<24)
 #define N_LOG2_MASK 0xFFFF
 
+#define SUBGROUPS_DISABLED 0xFFFFFFFF
+
 layout (binding = 4) readonly buffer S {float data_s[];};
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};