diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8807c3e2b6..6574955cf1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. - bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; + bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, mask != nullptr, use_mask_opt, logit_softcap != 0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp index 8c92c1adcd..0e41770806 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -33,6 +33,61 @@ layout (push_constant) uniform parameter { shared float minsh[NUM_SUBGROUPS]; shared float maxsh[NUM_SUBGROUPS]; +float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + +void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) { + const uint tid = gl_LocalInvocationIndex; + + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (!need_bounds_check || j0 + 3 < nem0) { + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } else { + [[unroll]] for (int c = 0; c < 4; ++c) { + if (j0 + c < nem0) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } +} + // For each Br x Bc block of the mask (input) buffer, read all values and check // if it's all -inf or all zero. Write out a two-bit code indicating which it is // (or zero for neither). Each workgroup processes 16 tiles and writes out a @@ -48,50 +103,15 @@ void main() { const uint i2 = gl_WorkGroupID.z % nem2; const uint i3 = gl_WorkGroupID.z / nem2; - float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); - uint result = 0; // Fast path for fully in-bounds blocks where we can do f16vec4 loads if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { - [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { - float min_v = FLT_MAX_OVER_2; - float max_v = -FLT_MAX_OVER_2; - [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { - uint j0 = (i + tid) % (Bc / 4); - uint j1 = (i + tid) / (Bc / 4); - - j0 *= 4; - j0 += (i0 * 16 + block_x) * Bc; - j1 += i1 * Br; - - vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); - [[unroll]] for (int c = 0; c < 4; ++c) { - min_v = min(min_v, f[c]); - max_v = max(max_v, f[c]); - } - } - min_v = subgroupMin(min_v); - max_v = subgroupMax(max_v); - if (gl_SubgroupInvocationID == 0) { - minsh[gl_SubgroupID] = min_v; - maxsh[gl_SubgroupID] = max_v; - } - barrier(); - if (tid == 0) { - [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { - min_v = min(min_v, minsh[i]); - max_v = max(max_v, maxsh[i]); - } - if (max_v <= -FLT_MAX_OVER_2) { - result |= 1 << (2*block_x); - } - if (min_v == 0.0f && max_v == 0.0f) { - result |= 2 << (2*block_x); - } - } - barrier(); + if ((i0 + 1) * 16 * Bc <= nem0) { + loadvec4(result, i0, i1, i2, i3, false); + } else { + loadvec4(result, i0, i1, i2, i3, true); } } else { [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {