From 449ec2ab0751fc713fe338da2ced153125b5c674 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Thu, 5 Feb 2026 09:26:38 -0600
Subject: [PATCH] vulkan: Preprocess FA mask to detect all-neg-inf and all-zero. (#19281)

Write out a 2-bit code per block and avoid loading the mask when it
matches these two common cases. Apply this optimization when the mask
is relatively large (i.e. prompt processing).
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 109 +++++++++++---
 .../vulkan-shaders/flash_attn.comp            |  39 ++---
 .../vulkan-shaders/flash_attn_base.glsl       |   6 +
 .../vulkan-shaders/flash_attn_cm1.comp        | 112 +++++++-------
 .../vulkan-shaders/flash_attn_cm2.comp        |  63 ++++----
 .../vulkan-shaders/flash_attn_mask_opt.comp   | 142 ++++++++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |   2 +
 tests/test-backend-ops.cpp                    |  10 +-
 8 files changed, 360 insertions(+), 123 deletions(-)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ff9cb7355c..4357da24d4 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -402,18 +402,19 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
+    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
+        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
 
     uint32_t HSK, HSV;
     bool small_rows, small_cache;
     FaCodePath path;
     bool aligned;
     bool f32acc;
+    bool use_mask_opt;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
+               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
     }
 };
 
@@ -820,6 +821,8 @@ struct vk_device_struct {
 
     std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
+    std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     vk_pipeline pipeline_count_experts;
@@ -1549,6 +1552,18 @@ struct vk_op_flash_attn_split_k_reduce_push_constants {
     uint32_t sinks;
 };
 
+struct vk_op_flash_attn_mask_opt_push_constants {
+    uint32_t nem0;
+    uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nbm1;
+    uint32_t nbm2;
+    uint32_t nbm3;
+    uint32_t nbd1;
+    uint32_t nbd2;
+    uint32_t nbd3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1757,6 +1772,7 @@ class vk_perf_logger {
                 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
                 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
                 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
+            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
             return name.str();
         }
         if (node->op == GGML_OP_TOP_K) {
@@ -3177,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device)
        { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; };
 
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
         // For large number of rows, 128 invocations seems to work best.
         // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
         // can't use 256 for D==80.
@@ -3209,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // AMD prefers loading K directly from global memory
         const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3221,18 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
         FaCodePath path = fa.first.path; \
         bool aligned = fa.first.aligned; \
         bool f32acc = fa.first.f32acc; \
+        bool use_mask_opt = fa.first.use_mask_opt; \
         if (path == FAPATH) { \
             if (aligned) { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ?
device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -4028,6 +4045,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + for (auto &it : device->pipeline_fa_mask_opt) { + auto BrBc = it.first; + ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size); + } + if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); } else { @@ -8400,8 +8422,6 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t acctype = f32acc ? 
4 : 2;
     const uint32_t f16vec4 = 8;
 
-    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
-
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
@@ -8418,7 +8438,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
 
     const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
@@ -8445,6 +8465,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
 
+    const uint32_t nem0 = mask ? mask->ne[0] : 0;
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
     const uint32_t nem3 = mask ? mask->ne[3] : 0;
@@ -8574,7 +8595,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
+    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
+
+    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8625,10 +8649,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
 
-    {
-        // Request descriptor sets
-        if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+
+    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
+    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
+
+    vk_pipeline pipeline_fa_mask_opt = nullptr;
+    if (use_mask_opt) {
+        std::lock_guard guard(ctx->device->mutex);
+        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
+        auto it = pipelines.find({Br, Bc});
+        if (it != pipelines.end()) {
+            pipeline_fa_mask_opt = it->second;
+        } else {
+            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
+        }
+        assert(pipeline_fa_mask_opt);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
+
+        if (ctx->prealloc_size_y < mask_opt_size) {
+            ctx->prealloc_size_y = mask_opt_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
         }
     }
 
@@ -8655,9 +8701,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+    vk_subbuffer mask_opt_buf = use_mask_opt ?
ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + if (use_mask_opt) + { + const vk_op_flash_attn_mask_opt_push_constants opt_pc = { + nem0, + nem1, + nem2, + (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)), + mask_opt_num_dwords, + mask_opt_num_dwords * CEIL_DIV(nem1, Br), + mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt, + { mask_buf, mask_opt_buf }, opt_pc, + { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 }); + ggml_vk_sync_buffers(ctx, subctx); + } + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, @@ -8672,13 +8739,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx gqa_ratio, split_kv, split_k }; if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to @@ -8697,7 +8766,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_x *= pipeline->wg_denoms[0]; } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 3ce8d07be8..49a3c530cb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -94,6 +94,10 @@ void main() { } } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -104,15 +108,28 @@ void main() { uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != 
MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -120,25 +137,12 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; - max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } } float Sf[Br][cols_per_thread]; @@ -185,7 +189,7 @@ void main() { } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; @@ -256,9 +260,6 @@ void main() { barrier(); } - // prevent race on tmpsh - barrier(); - // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 23a4d2c005..252451101a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -10,6 +10,7 @@ layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (constant_id = 7) const uint32_t SubGroupSize = 32; layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; +layout (constant_id = 9) const bool USE_MASK_OPT = false; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -66,6 +67,11 @@ layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 #if defined(DATA_A_F32) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 83d52d19d6..89af3697e1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -shared float tmpsh[row_split]; - const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; @@ -134,6 +132,10 @@ void main() { } } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -144,66 +146,74 @@ void main() { uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * 
p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { + mask_cache[idx] = f16vec4(0); + } - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / (Br / 4); - uint32_t r = (idx + tid) % (Br / 4); - if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { - if ((!KV_bounds_check || j * Bc + c < KV)) { - f16vec4 m; - if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); - max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); - } else if (i * Br + r * 4 + 2 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], - 0.0); - max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); - } else if (i * Br + r * 4 + 1 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - 0.0, - 0.0); - max_mask = max(max(max_mask, float(m[0])), float(m[1])); - } else if (i * Br + r * 4 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - 0.0, - 0.0, - 0.0); - max_mask = max(max_mask, float(m[0])); - } else { - m = f16vec4(0.0); + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + 
r * 4 + 1) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                        0.0);
+                            max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+                        } else if (i * Br + r * 4 + 1 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                        0.0,
+                                        0.0);
+                            max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+                        } else if (i * Br + r * 4 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        0.0,
+                                        0.0,
+                                        0.0);
+                            max_mask = max(max_mask, float(m[0]));
+                        } else {
+                            m = f16vec4(0.0);
+                        }
+                        mask_cache[idx / WorkGroupSize] = m;
+                    }
+                }
+            }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
         }
 
         if (K_LOAD_SHMEM != 0) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 54f1b0b622..47b110621b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -138,48 +138,53 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
     uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
     if (p.nem2 != 1 || p.nem3 != 1) {
         m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-            if (nem1_bounds_check) {
-                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
+                if (nem1_bounds_check) {
+                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
-            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                } else {
+                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                    // Don't clamp against nem1 when GQA is enabled
+                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-            // skip the block if the mask is entirely -inf
-            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
-        } else {
-            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-            // Don't clamp against nem1 when GQA is enabled
-            uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
-            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
-            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-
-            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
-
-            coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-            // skip the block if the mask is entirely -inf
-            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                continue;
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
                 }
             }
         }
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
new file mode 100644
index 0000000000..8c92c1adcd
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
@@ -0,0 +1,142 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
+layout (constant_id = 2) const uint Br = 32;
+layout (constant_id = 3) const uint Bc = 32;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];};
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+layout (push_constant) uniform parameter {
+    uint nem0;
+    uint nem1;
+    uint nem2;
+    uint nbm1;
+    uint nbm2;
+    uint nbm3;
+    uint nbd1;
+    uint nbd2;
+    uint nbd3;
+};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
+shared float minsh[NUM_SUBGROUPS];
+shared float maxsh[NUM_SUBGROUPS];
+
+// For each Br x Bc block of the mask (input) buffer, read all values and check
+// if it's all -inf or all zero. Write out a two-bit code indicating which it is
+// (or zero for neither). Each workgroup processes 16 tiles and writes out a
+// 32-bit result mask.
+//
+// TODO: This is a lot of work per workgroup, might make sense to split this into
+// more workgroups in the future.
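+//
+// Output encoding, per 32-bit word (16 consecutive Br x Bc blocks along the
+// nem0/KV dimension; block b occupies bits [2*b+1 : 2*b]):
+//   0 = mixed values         (FA shaders load and apply the mask as usual)
+//   1 = MASK_OPT_ALL_NEG_INF (the whole KV block can be skipped)
+//   2 = MASK_OPT_ALL_ZERO    (the mask load can be skipped)
+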
+void main() { + // Each workgroup handles a row + const uint tid = gl_LocalInvocationIndex; + const uint i0 = gl_WorkGroupID.x; + const uint i1 = gl_WorkGroupID.y; + const uint i2 = gl_WorkGroupID.z % nem2; + const uint i3 = gl_WorkGroupID.z / nem2; + + float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + + uint result = 0; + + // Fast path for fully in-bounds blocks where we can do f16vec4 loads + if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && + ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } else { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) { + if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) { + continue; + } + uint j0 = (i + tid) % Bc; + uint j1 = (i + tid) / Bc; + + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (j0 < nem0 && j1 < nem1) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } + + if (tid == 0) { + data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ca486a288a..42ebc21e2a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -790,6 +790,8 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cecdf47038..fbe23037cc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -169,20 +169,22 @@ static void init_tensor_kq_mask(ggml_tensor * tensor, float min 
= -1.0f, float m const int blck0 = 128; const int blck1 = 64; - // number of INF blocks - const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1); + // number of INF/zero blocks + const int n_inf_zero_blocks = 0.2*(ne0*ne1*ne2*ne3)/(blck0*blck1); - for (int b = 0; b < n_inf_blocks; b++) { + for (int b = 0; b < n_inf_zero_blocks; b++) { const int p3 = (rd() % ne3); const int p2 = (rd() % ne2); const int p1 = (rd() % ne1); const int p0 = (rd() % ne0); + bool inf = rd() & 1; + for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) { const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0; for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) { - data_f32[idx + i0] = -INFINITY; + data_f32[idx + i0] = inf ? -INFINITY : 0.0f; } } }
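
Below is a minimal, illustrative host-side sketch of the 2-bit block-code convention this patch introduces; it is not part of the patch. The helper names (mask_block_code, mask_opt_words_per_row, decode_block_code) are hypothetical, while the constants and bit layout mirror flash_attn_base.glsl and flash_attn_mask_opt.comp.

    // Hypothetical illustration only -- these helpers do not exist in the tree.
    // Each 32-bit word summarizes 16 consecutive Br x Bc mask blocks along the
    // KV (nem0) dimension; block b occupies bits [2*b+1 : 2*b].
    #include <cstdint>

    enum mask_block_code : uint32_t {
        MASK_MIXED       = 0, // neither case: the mask is loaded and applied
        MASK_ALL_NEG_INF = 1, // MASK_OPT_ALL_NEG_INF: skip the whole KV block
        MASK_ALL_ZERO    = 2, // MASK_OPT_ALL_ZERO: skip the mask load
    };

    // Words per row of Br-tall tiles; mirrors
    // mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc) in ggml-vulkan.cpp.
    inline uint32_t mask_opt_words_per_row(uint32_t nem0, uint32_t Bc) {
        return (nem0 + 16 * Bc - 1) / (16 * Bc);
    }

    // Code for KV block j within one tile row, where `words` points at that
    // row's first word; mirrors the shader expression
    // (mask_opt >> ((j % 16) * 2)) & 0x3.
    inline mask_block_code decode_block_code(const uint32_t * words, uint32_t j) {
        return mask_block_code((words[j / 16] >> ((j % 16) * 2)) & 0x3u);
    }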