diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e5dcd3cbda..82933ae033 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -8422,6 +8422,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
+    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
+
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
@@ -8438,7 +8440,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
 
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 914f131c96..0735f67854 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -130,6 +130,7 @@ void main() {
         if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
+            float max_mask = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) % Bc;
                 uint32_t r = (idx + tid) / Bc;
@@ -137,12 +138,25 @@ void main() {
                     if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
                         float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                         masksh[c][r] = m;
+                        max_mask = max(max_mask, m);
                     } else {
                         masksh[c][r] = float(0);
                     }
                 }
             }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
             barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
         }
 
         float Sf[Br][cols_per_thread];
@@ -260,6 +274,9 @@ void main() {
         barrier();
     }
 
+    // prevent race on tmpsh
+    barrier();
+
     // reduce across threads
 
     [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index b317773823..19630972da 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
+shared float tmpsh[row_split];
+
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
@@ -213,6 +215,19 @@ void main() {
                     }
                 }
             }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
         }
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 39f0c4d23b..853f17fa16 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -176,7 +176,14 @@ void main() {
             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
             tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
+            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
             coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            // skip the block if the mask is entirely -inf
+            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
         } else {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             // Don't clamp against nem1 when GQA is enabled
@@ -184,7 +191,14 @@ void main() {
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
             tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
+            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+
             coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+            // skip the block if the mask is entirely -inf
+            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
         }
     }
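
For context on how the scalar and coopmat1 shaders reach the `continue`, here is a minimal CPU sketch of the same two-level maximum reduction, not part of the patch: each thread keeps a running `max_mask` while loading its slice of the mask, `subgroupAll` turns that into a per-subgroup verdict, lane 0 publishes the verdict to the shared `tmpsh` scratch (one slot per subgroup), and after a barrier every thread folds the slots to decide whether the whole Br x Bc block is effectively -inf. The container shapes, the `block_is_fully_masked` name, and the serial loops are illustrative assumptions; only the threshold logic mirrors the shaders.

// Standalone sketch (not part of the patch): models the two-level reduction
// the scalar and coopmat1 shaders use to skip all-(-inf) mask blocks.
#include <algorithm>
#include <cfloat>
#include <cstdio>
#include <vector>

// Same threshold the shaders use: values at or below -FLT_MAX/2 count as -inf.
static const float NEG_FLT_MAX_OVER_2 = -FLT_MAX / 2.0f;

// per_subgroup_masks[s] holds the mask values seen by the threads of subgroup s
// (an illustrative stand-in for the parallel loads in the shader).
static bool block_is_fully_masked(const std::vector<std::vector<float>>& per_subgroup_masks) {
    std::vector<float> tmpsh(per_subgroup_masks.size());  // shared scratch, one slot per subgroup
    for (size_t s = 0; s < per_subgroup_masks.size(); ++s) {
        // Each thread tracks a running max while loading the mask (max_mask in
        // the shader); subgroupAll() then agrees on a per-subgroup verdict.
        float max_mask = NEG_FLT_MAX_OVER_2;
        for (float m : per_subgroup_masks[s]) {
            max_mask = std::max(max_mask, m);
        }
        const bool all_less = max_mask <= NEG_FLT_MAX_OVER_2;
        // Lane 0 publishes the verdict (tmpsh[gl_SubgroupID] in the shader).
        tmpsh[s] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
    }
    // After a barrier, every thread folds the per-subgroup slots; a single
    // 0.0f lifts the max above the threshold and the block is processed.
    float max_mask = NEG_FLT_MAX_OVER_2;
    for (float v : tmpsh) {
        max_mask = std::max(max_mask, v);
    }
    return max_mask <= NEG_FLT_MAX_OVER_2;  // true -> the shader executes `continue`
}

int main() {
    // Two subgroups: one sees only -inf, the other sees a real mask value.
    std::vector<std::vector<float>> masks = {
        { NEG_FLT_MAX_OVER_2, NEG_FLT_MAX_OVER_2 },
        { NEG_FLT_MAX_OVER_2, 0.0f },
    };
    std::printf("skip block: %s\n", block_is_fully_masked(masks) ? "yes" : "no");  // prints "no"
    return 0;
}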