relax flash attention split_k condition to allow non-gqa use

Ruben Ortlam 2026-02-10 19:50:24 +01:00
parent d6a004547f
commit 9f9b701ff5
5 changed files with 150 additions and 60 deletions

View File

@@ -8668,19 +8668,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];
GGML_ASSERT(Br == pipeline->wg_denoms[0]);
const uint32_t Tr = CEIL_DIV(N, Br);
// Try to use split_k when KV is large enough to be worth the overhead.
// Must either be a single batch or be using gqa, we can't mix the two.
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
// Try to run two workgroups per SM.
if (gqa_ratio > 1 && workgroups_x <= Br) {
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
} else if (gqa_ratio <= 1) {
uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
if (total_wgs_no_split < shader_core_count * 2) {
split_k = shader_core_count * 2 / total_wgs_no_split;
}
}
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
}
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
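Note (not part of the patch): a minimal standalone sketch of the relaxed split_k heuristic above, with the surrounding state collapsed into plain parameters. The helper names (pick_split_k, ceil_div, round_up) and the numbers in the comments are illustrative only; round_up assumes the power-of-two alignment that ROUNDUP_POW2 expects. Compared to the old condition (workgroups_x == 1 || gqa_ratio > 1), the non-GQA path no longer needs a single row tile, because the shaders now index rows by i * Br + row when storing split results.

    #include <algorithm>
    #include <cstdint>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
    static uint32_t round_up(uint32_t v, uint32_t a) { return ((v + a - 1) / a) * a; }

    static uint32_t pick_split_k(uint32_t gqa_ratio, uint32_t workgroups_x, uint32_t workgroups_y,
                                 uint32_t workgroups_z, uint32_t Br, uint32_t Tr,
                                 uint32_t shader_core_count, uint32_t KV, uint32_t alignment) {
        uint32_t split_k = 1;
        if (gqa_ratio > 1 && workgroups_x <= Br) {
            // GQA path: keep the old "two workgroups per shader core" rule.
            split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
        } else if (gqa_ratio <= 1) {
            // Non-GQA path (new): split only when the plain Tr x y x z dispatch underfills the GPU.
            // e.g. shader_core_count = 30, Tr * workgroups_y * workgroups_z = 12 -> split_k = 60 / 12 = 5.
            const uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
            if (total_wgs_no_split < shader_core_count * 2) {
                split_k = shader_core_count * 2 / total_wgs_no_split;
            }
        }
        if (split_k > 1) {
            // Re-derive split_k from an alignment-friendly chunk size.
            // e.g. KV = 4096, split_k = 5, alignment = 256:
            //   split_kv = round_up(819, 256) = 1024, split_k = ceil_div(4096, 1024) = 4.
            const uint32_t split_kv = round_up(std::max(1u, KV / split_k), alignment);
            split_k = ceil_div(KV, split_kv);
        }
        // Clamp for the sketch; the real code relies on the surrounding guards.
        return std::max(split_k, 1u);
    }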
@@ -8694,10 +8705,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_preallocate_buffers(ctx, subctx);
}
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
const uint32_t Br = rows_cols[0];
const uint32_t Bc = rows_cols[1];
const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
@@ -8777,15 +8784,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
workgroups_x *= pipeline->wg_denoms[0];
// We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
uint32_t dispatch_x;
if (gqa_ratio > 1) {
workgroups_x *= pipeline->wg_denoms[0];
dispatch_x = split_k * workgroups_x;
} else {
dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
}
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
pc, { dispatch_x, workgroups_y, workgroups_z });
ggml_vk_sync_buffers(ctx, subctx);
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
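Note (not part of the patch): a hedged sketch of the split_k scratch buffer layout that the reservation comment above describes and that the shader stores below rely on: all partial O tiles first, then the per-row L and M values. The struct and member names are illustrative; "iq1" here stands for gqa_iq1 in the GQA path and for the global query row (i * Br + row) in the new non-GQA path.

    #include <cstdint>

    struct split_k_layout {
        uint32_t HSV, ne1, ne2, ne3, k_num;

        // Region sizes in float elements: k_num * ne2 * ne3 slots of HSV * ne1 for O,
        // followed by k_num * ne2 * ne3 slots of 2 * ne1 for the L and M rows.
        uint64_t o_region_elems()  const { return (uint64_t) HSV * ne1 * k_num * ne2 * ne3; }
        uint64_t lm_region_elems() const { return (uint64_t) 2 * ne1 * k_num * ne2 * ne3; }
        uint64_t total_elems()     const { return o_region_elems() + lm_region_elems(); }

        // Start of the O tile for one (split, iq1, iq3) slot, in elements
        // (the shaders divide this by 4 when indexing the vec4 view).
        uint64_t o_offset(uint32_t split, uint32_t iq1, uint32_t iq3) const {
            return (uint64_t) HSV * ne1 * (split + (uint64_t) k_num * (iq1 + (uint64_t) ne2 * iq3));
        }
        // Start of the L values for the same slot; the M values follow ne1 elements later.
        uint64_t lm_offset(uint32_t split, uint32_t iq1, uint32_t iq3) const {
            return o_region_elems() + (uint64_t) 2 * ne1 * (split + (uint64_t) k_num * (iq1 + (uint64_t) ne2 * iq3));
        }
    };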

View File

@@ -457,27 +457,47 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
if (global_row < N && d_tid == 0 && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
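Note (not part of the patch): the Of, Lf and Mf values stored above are consumed by the split_k reduce pass, which this commit does not touch. For context, this is the standard flash-attention merge that per-split L and M make possible, written as a host-side sketch rather than the actual reduce shader (the optional sinks contribution from the reduce push constants is omitted here).

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Combine k_num partial results for one output row of head size HSV.
    // O_part[k] was accumulated with softmax normalizer L_part[k] and row max M_part[k].
    static std::vector<float> combine_splits(const std::vector<std::vector<float>> & O_part,
                                             const std::vector<float> & L_part,
                                             const std::vector<float> & M_part,
                                             uint32_t HSV) {
        const size_t k_num = O_part.size();
        float M = -INFINITY;
        for (size_t k = 0; k < k_num; ++k) M = std::max(M, M_part[k]);

        std::vector<float> O(HSV, 0.0f);
        float L = 0.0f;
        for (size_t k = 0; k < k_num; ++k) {
            const float scale = std::exp(M_part[k] - M);   // rescale each split to the global max
            L += scale * L_part[k];
            for (uint32_t d = 0; d < HSV; ++d) O[d] += scale * O_part[k][d];
        }
        for (uint32_t d = 0; d < HSV; ++d) O[d] /= L;      // the final division by L happens here
        return O;
    }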

View File

@@ -192,10 +192,16 @@ void init_indices()
KV = p.KV;
if (p.k_num > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
if (p.gqa_ratio > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else {
gqa_iq1 = 0;
split_k_index = gl_WorkGroupID.x % p.k_num;
i = gl_WorkGroupID.x / p.k_num;
}
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
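Note (not part of the patch): a compact restatement of how init_indices now unpacks gl_WorkGroupID.x, as a plain C++ helper for readability. The struct and function names are illustrative, and the final non-GQA, non-split case falls outside the hunk shown above.

    #include <cstdint>

    struct fa_indices {
        uint32_t i;             // row-tile index (which block of Br query rows)
        uint32_t gqa_iq1;       // query position when GQA packs heads into the tile rows
        uint32_t split_k_index; // which KV chunk this workgroup handles
    };

    static fa_indices decode_workgroup_x(uint32_t wg_x, uint32_t k_num, uint32_t gqa_ratio) {
        fa_indices idx = {0, 0, 0};
        if (k_num > 1) {
            if (gqa_ratio > 1) {
                // GQA + split_k: x packs (query position, split); there is a single row tile.
                idx.gqa_iq1       = wg_x / k_num;
                idx.split_k_index = wg_x % k_num;
            } else {
                // non-GQA + split_k (new in this commit): x packs (row tile, split).
                idx.i             = wg_x / k_num;
                idx.split_k_index = wg_x % k_num;
            }
        } else if (gqa_ratio > 1) {
            // GQA without split_k: x is the query position.
            idx.gqa_iq1 = wg_x;
        } else {
            // Remaining case is outside the hunk above; x is the row tile there.
            idx.i = wg_x;
        }
        return idx;
    }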

View File

@@ -525,25 +525,48 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
}
}
if (global_row < N && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
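Note (not part of the patch): in the cm1 variant above, each rowgroup of threads_per_rowgroup lanes strides across the HSV/4 vec4 columns of a row: lane col_tid handles columns d = d0 + col_tid and keeps them in a compact register array indexed by d0 / threads_per_rowgroup. A tiny illustration of that mapping; the values HSV = 128 and threads_per_rowgroup = 8 are just an example.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t HSV_v4 = 32;   // HSV / 4 vec4 columns (e.g. HSV = 128)
        const uint32_t tpr    = 8;    // threads_per_rowgroup
        for (uint32_t col_tid = 0; col_tid < tpr; ++col_tid) {
            for (uint32_t d0 = 0; d0 < HSV_v4; d0 += tpr) {
                const uint32_t d       = d0 + col_tid;   // global vec4 column this lane stores
                const uint32_t d_local = d0 / tpr;       // index into Of[r][...]
                printf("lane %u: vec4 column %2u lives in register Of[r][%u]\n",
                       (unsigned) col_tid, (unsigned) d, (unsigned) d_local);
            }
        }
        return 0;
    }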

View File

@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
// Store O values for non-GQA split_k. Rows are tokens, not heads.
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c < HSV) {
uint32_t o_off = HSV * p.ne1
* (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
}
return elem;
}
// Store L/M values for non-GQA split_k.
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c == 0) {
uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+ p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -290,13 +312,19 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
} else {
coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
}
return;
}
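Note (not part of the patch): the coopmat2 path stores through per-element callbacks rather than explicit loops: coopMatPerElementNV calls the supplied function once per accumulator element with its row, column and value plus the trailing arguments, which is how the new perElemOpNonGqaSplitKStore / perElemOpNonGqaSplitKStoreCol0 helpers receive (0u or p.ne1, iq2, N) and then filter on c < HSV or c == 0. Below is a rough C++ analogy of that dispatch, not the GLSL extension API.

    #include <cstdint>
    #include <vector>

    // Conceptual stand-in for a Br x HSV_pad accumulator: the callback sees every
    // element's (row, col, value) and the extra arguments, and decides what to store.
    // The Col0-style callbacks only act when col == 0, so L and M get one write per row;
    // the O callback skips the padding columns with a col < HSV check.
    template <typename T, typename Fn>
    void per_element(const std::vector<std::vector<T>> & mat, Fn && op,
                     uint32_t arg0, uint32_t iq2, uint32_t N) {
        for (uint32_t r = 0; r < (uint32_t) mat.size(); ++r) {
            for (uint32_t c = 0; c < (uint32_t) mat[r].size(); ++c) {
                op(r, c, mat[r][c], arg0, iq2, N);
            }
        }
    }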