diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 69345be866..861e17512e 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -8668,19 +8668,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;

+    auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+
+    GGML_ASSERT(Br == pipeline->wg_denoms[0]);
+    const uint32_t Tr = CEIL_DIV(N, Br);
+
     // Try to use split_k when KV is large enough to be worth the overhead.
-    // Must either be a single batch or be using gqa, we can't mix the two.
-    if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
-        // Try to run two workgroups per SM.
+    if (gqa_ratio > 1 && workgroups_x <= Br) {
         split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
-        if (split_k > 1) {
-            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
-            // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
-            split_k = CEIL_DIV(KV, split_kv);
+    } else if (gqa_ratio <= 1) {
+        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
+        if (total_wgs_no_split < shader_core_count * 2) {
+            split_k = shader_core_count * 2 / total_wgs_no_split;
         }
     }

+    if (split_k > 1) {
+        // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+        // of "align", so recompute split_k based on that.
+        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
+        split_k = CEIL_DIV(KV, split_kv);
+    }
+
     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
     // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
@@ -8694,10 +8705,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }

-    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
-    const uint32_t Br = rows_cols[0];
-    const uint32_t Bc = rows_cols[1];
-
     const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
     const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;

@@ -8777,15 +8784,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
-        workgroups_x *= pipeline->wg_denoms[0];
+
+        // We reuse workgroups_x to mean the number of splits, so we need to
+        // cancel out the divide by wg_denoms[0].
+        uint32_t dispatch_x;
+        if (gqa_ratio > 1) {
+            workgroups_x *= pipeline->wg_denoms[0];
+            dispatch_x = split_k * workgroups_x;
+        } else {
+            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
+        }
+
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                   {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
-                                  // We only use split_k when group query attention is enabled, which means
-                                  // there's no more than one tile of rows (i.e. workgroups_x would have been
-                                  // one). We reuse workgroups_x to mean the number of splits, so we need to
-                                  // cancel out the divide by wg_denoms[0].
-                                  pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
+                                  pc, { dispatch_x, workgroups_y, workgroups_z });
         ggml_vk_sync_buffers(ctx, subctx);

         const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 0a425ce75f..563a2bcbdc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -457,27 +457,47 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;

-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            const uint row = tile_row(r);
-            if (row < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
+                    }
+                }
+            }
+
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
+                    }
+                }
+
+                if (global_row < N && d_tid == 0 && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
                 }
             }
         }
-
-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            const uint row = tile_row(r);
-            if (row < N) {
-                perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
-            }
-        }
-
         return;
     }

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index 88b6d65bee..40f529dc23 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -192,10 +192,16 @@ void init_indices()
     KV = p.KV;

     if (p.k_num > 1) {
-        i = 0;
-        // batch and split_k share gl_WorkGroupID.x
-        gqa_iq1 = gl_WorkGroupID.x / p.k_num;
-        split_k_index = gl_WorkGroupID.x % p.k_num;
+        if (p.gqa_ratio > 1) {
+            i = 0;
+            // batch and split_k share gl_WorkGroupID.x
+            gqa_iq1 = gl_WorkGroupID.x / p.k_num;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+        } else {
+            gqa_iq1 = 0;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+            i = gl_WorkGroupID.x / p.k_num;
+        }
     } else if (p.gqa_ratio > 1) {
         i = 0;
         gqa_iq1 = gl_WorkGroupID.x;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index ebb55c9504..e35d4666cd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -525,25 +525,48 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;

-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
-                    const uint d = d0 + col_tid;
-                    if (d >= HSV/4) break;
-                    const uint d_local = d0 / threads_per_rowgroup;
-                    gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        const uint d_local = d0 / threads_per_rowgroup;
+                        gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
+                    }
                 }
             }
-        }

-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
+                    }
+                }
+
+                if (global_row < N && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
             }
         }

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 853f17fa16..0ea181342c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
+// Store O values for non-GQA split_k. Rows are tokens, not heads.
+D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c < HSV) {
+        uint32_t o_off = HSV * p.ne1
+            * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Store L/M values for non-GQA split_k.
+ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c == 0) {
+        uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+            + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
+    }
+    return elem;
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
@@ -290,13 +312,19 @@ void main() {
     if (p.k_num > 1) {
         coopmat O_D = coopmat(O);

-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
-        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        } else {
+            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
+            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
+        }

         return;
     }
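
Note on the host-side heuristic: the first ggml-vulkan.cpp hunk replaces the old "single batch or GQA only" split_k condition with two paths. When gqa_ratio > 1 it keeps the old "roughly two workgroups per shader core" rule; when gqa_ratio <= 1 it only splits KV if the unsplit grid (Tr row tiles x workgroups_y x workgroups_z) would underfill the GPU. The standalone C++ sketch below restates that selection under stated assumptions: ceil_div/roundup_pow2 are stand-ins for ggml's CEIL_DIV/ROUNDUP_POW2 macros, choose_split_k and SplitK are illustrative names (not ggml API), and the numbers in main() are made up.

    // Sketch of the split_k selection logic from the ggml-vulkan.cpp hunk above.
    // ceil_div/roundup_pow2 stand in for ggml's CEIL_DIV/ROUNDUP_POW2 macros;
    // choose_split_k, SplitK and the values in main() are illustrative only.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
    // Round a up to a multiple of b, where b is a power of two.
    static uint32_t roundup_pow2(uint32_t a, uint32_t b) { return (a + b - 1) & ~(b - 1); }

    struct SplitK { uint32_t split_k, split_kv; };

    static SplitK choose_split_k(uint32_t gqa_ratio, uint32_t workgroups_x, uint32_t workgroups_y,
                                 uint32_t workgroups_z, uint32_t Br, uint32_t N, uint32_t KV,
                                 uint32_t alignment, uint32_t shader_core_count) {
        uint32_t split_k  = 1;
        uint32_t split_kv = KV;

        const uint32_t Tr = ceil_div(N, Br); // number of row tiles, as in the patch

        if (gqa_ratio > 1 && workgroups_x <= Br) {
            // GQA path: aim for roughly two workgroups per shader core.
            // (std::max is a safety clamp added in this sketch; the patch assumes a small grid here.)
            split_k = std::max(1u, shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z));
        } else if (gqa_ratio <= 1) {
            // Non-GQA path: only split when the unsplit grid would underfill the GPU.
            const uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
            if (total_wgs_no_split < shader_core_count * 2) {
                split_k = shader_core_count * 2 / total_wgs_no_split;
            }
        }

        if (split_k > 1) {
            // Evenly split KV into split_k chunks, rounded up to the alignment,
            // then recompute split_k from the actual chunk size.
            split_kv = roundup_pow2(std::max(1u, KV / split_k), alignment);
            split_k  = ceil_div(KV, split_kv);
        }
        return { split_k, split_kv };
    }

    int main() {
        // Example: single-token decode, 32 heads, no GQA, long KV cache.
        SplitK s = choose_split_k(/*gqa_ratio=*/1, /*workgroups_x=*/1, /*workgroups_y=*/32,
                                  /*workgroups_z=*/1, /*Br=*/16, /*N=*/1, /*KV=*/8192,
                                  /*alignment=*/256, /*shader_core_count=*/32);
        printf("split_k=%u split_kv=%u\n", s.split_k, s.split_kv); // split_k=2 split_kv=4096
        return 0;
    }

The shader-side hunks mirror the second case: when p.gqa_ratio <= 1, flash_attn_base.glsl decodes both the row-tile index i and split_k_index from gl_WorkGroupID.x, and the shaders index the split_k temporaries by global row (i * Br + row) instead of by gqa_iq1.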