From 02ccf81496b25365c42006c81fdfad7dd481c303 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 14 Feb 2026 11:43:31 +0100 Subject: [PATCH] tmpsh size fix --- .../vulkan-shaders/flash_attn.comp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 8974593e9f..d8c4f4134d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -35,7 +35,7 @@ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; // If SubGroupSize is set to 0xFFFFFFFF then only use shmem reductions -const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : 1) : WorkGroupSize; +const uint32_t tmpsh_size = (SubGroupSize != SUBGROUPS_DISABLED) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; shared float tmpsh[tmpsh_size]; shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; @@ -143,7 +143,7 @@ void main() { if (mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; + float max_mask = NEG_FLT_MAX_OVER_2; barrier(); [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; @@ -152,25 +152,25 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c * masksh_stride + r] = m; - max_mask = max(max_mask, float(m)); + max_mask = max(max_mask, float(m)); } else { masksh[c * masksh_stride + r] = FLOAT_TYPE(0); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } }