From 33f890e5799d1bb539f72f3d7e1ea30b8cff1123 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 21 Jan 2026 10:43:43 -0600 Subject: [PATCH] vulkan: support flash attention GQA/split_k with small batches (#18938) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 42 +++++++++++++------ .../vulkan-shaders/flash_attn.comp | 13 +++--- .../vulkan-shaders/flash_attn_base.glsl | 17 +++++--- .../vulkan-shaders/flash_attn_cm1.comp | 13 +++--- .../vulkan-shaders/flash_attn_cm2.comp | 13 +++--- .../flash_attn_split_k_reduce.comp | 17 ++++---- tests/test-backend-ops.cpp | 3 ++ 7 files changed, 74 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 62a878556b..739361e777 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1518,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants { uint32_t num_blocks; }; +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -3982,7 +3991,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); @@ -8457,14 +8466,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(0); } - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; + workgroups_y /= gqa_ratio; } bool small_rows = N <= get_fa_num_small_rows(path); @@ -8526,6 +8535,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; @@ -8533,22 +8544,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { + // Try to use split_k when KV is large enough to be worth the overhead. + // Must either be a single batch or be using gqa, we can't mix the two. + if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; } } // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. + // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8559,7 +8572,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } @@ -8608,7 +8620,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - + workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, @@ -8616,15 +8628,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d502..3ce8d07be8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -53,7 +53,7 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -101,9 +101,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -320,7 +320,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -332,7 +333,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -378,7 +379,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c46..29b5c7c3a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +173,19 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else if (p.gqa_ratio > 1) { + i = 0; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 0; } Tr = CEIL_DIV(N, Br); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index c995ab140e..0eb50fe58f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -90,7 +90,7 @@ void main() { barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -141,9 +141,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -370,7 +370,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -382,7 +383,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -428,7 +429,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a71996383..d49a8da65f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -111,7 +111,7 @@ void main() { coopmat Q; coopmat Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); Qf16 = coopmat(Q); @@ -138,9 +138,9 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -272,10 +272,11 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -325,7 +326,7 @@ void main() { [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a8..68917fc0bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6bb781737e..9f61c6483d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8460,6 +8460,9 @@ static std::vector> make_test_cases_perf() { // Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012 test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) {