diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index b7d232eb66..af0b371cfa 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -19,7 +19,9 @@
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;

-const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t row_split = 4;
+const uint32_t rows_per_thread = Br / row_split;
+const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;

@@ -44,7 +46,9 @@
 shared float tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
 shared FLOAT_TYPE masksh[Bc][Br];
-shared FLOAT_TYPEV4 Qf[Br][HSK / 4];
+
+const uint qfstride = HSK / 4 + 1;
+shared FLOAT_TYPEV4 Qf[Br * qfstride];

 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -54,8 +58,12 @@ void main() {
     init_indices();

     const uint32_t tid = gl_LocalInvocationIndex;
+    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+
+#define tile_row(r) (row_tid * rows_per_thread + (r))

     uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
@@ -64,37 +72,37 @@ void main() {
         uint32_t r = (idx + tid) / (HSK / 4);
         if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
-            Qf[r][d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
+            Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();

-    ACC_TYPEV4 Of[Br][HSV_per_thread / 4];
+    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] = ACC_TYPEV4(0.0);
         }
     }

-    float Lf[Br], Mf[Br];
+    float Lf[rows_per_thread], Mf[rows_per_thread];

     // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
     const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lf[r] = 0;
         Mf[r] = NEG_FLT_MAX_OVER_2;
     }

-    ACC_TYPE slope[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    ACC_TYPE slope[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         slope[r] = ACC_TYPE(1.0);
     }

     // ALiBi
     if (p.max_bias > 0.0f) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
         }
     }

@@ -163,8 +171,8 @@ void main() {
             }
         }

-        ACC_TYPE Sf[Br][cols_per_thread];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        ACC_TYPE Sf[rows_per_thread][cols_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 Sf[r][c] = ACC_TYPE(0.0);
             }
@@ -184,8 +192,8 @@
 #else
                 FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Sf[r][c] += ACC_TYPE(dot(Qf[r][d * D_split + d_tid], K_Tf));
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf));
                 }
             }
         }
@@ -193,14 +201,14 @@
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             // Compute sum across the D_split
             [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                 }
             }
         }

         if (LOGIT_SOFTCAP) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                     Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
                 }
@@ -209,8 +217,8 @@
         if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][r];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    FLOAT_TYPE mvf = masksh[c * cols_per_iter + col_tid][tile_row(r)];

                     Sf[r][c] += slope[r]*mvf;
                 }
@@ -218,9 +226,9 @@
             barrier();
         }

-        FLOAT_TYPE Pf[Br][cols_per_thread];
-        float eMf[Br];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        FLOAT_TYPE Pf[rows_per_thread][cols_per_thread];
+        float eMf[rows_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
@@ -252,7 +260,7 @@
         }

         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
             }
         }
@@ -270,7 +278,7 @@
 #else
                 FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Of[r][d] += ACC_TYPE(Pf[r][c] * Vf);
                 }
             }
@@ -284,7 +292,7 @@ void main() {
     // reduce across threads
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         float rowmaxf, eMf;

         tmpsh[tid] = Mf[r];
@@ -346,21 +354,23 @@ void main() {
         // note: O and Q have swapped coord 1,2.
         uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                        perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
                 }
             }
         }

         o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
+                perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
             }
         }

@@ -368,8 +378,8 @@ void main() {
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);

             float ms = 1.0f;
             float vs = 1.0f;
@@ -388,13 +398,13 @@ void main() {
         }
     }

-    float Lfrcp[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    float Lfrcp[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }

     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] *= ACC_TYPE(Lfrcp[r]);
 #if defined(ACC_TYPE_MAX)
             Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
@@ -405,21 +415,23 @@
     uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
     if (p.gqa_ratio > 1) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                        perElemOpGqaStore(row, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                     }
                 }
             }
         }
     } else {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (i * Br + r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (i * Br + row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + row) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                     }
                 }
             }
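Not part of the patch: below is a standalone C sketch of the thread-to-tile mapping that the row_split change introduces. The numeric constants are illustrative assumptions (in the shader they are specialization constants and push constants); the index math mirrors the tid / row_tid / col_tid / d_tid and tile_row() decomposition above, and the checks confirm that the four row groups together still cover every row and column of the Br x Bc score tile.

/* Host-side sketch of the row_split thread-to-tile mapping (assumed constants). */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

enum {
    WorkGroupSize = 128,  /* assumed workgroup size (specialization constant in the shader) */
    D_split       = 16,   /* assumed threads cooperating on the head dimension */
    Br            = 32,   /* assumed query rows per tile */
    Bc            = 32,   /* assumed KV columns per tile */
    row_split     = 4,    /* from the patch: Br rows split across 4 thread groups */
};

enum {
    rows_per_thread = Br / row_split,
    cols_per_iter   = WorkGroupSize / D_split / row_split,
    cols_per_thread = Bc / cols_per_iter,
};

int main(void) {
    int row_hits[Br];
    int col_hits[Bc];
    memset(row_hits, 0, sizeof(row_hits));
    memset(col_hits, 0, sizeof(col_hits));

    for (uint32_t tid = 0; tid < WorkGroupSize; ++tid) {
        /* Same decomposition as the shader after the patch. */
        const uint32_t threads_per_rowgroup = WorkGroupSize / row_split;
        const uint32_t row_tid = tid / threads_per_rowgroup;
        const uint32_t d_tid   = tid % D_split;
        const uint32_t col_tid = (tid % threads_per_rowgroup) / D_split;

        /* tile_row(r): the rows of the Br x Bc tile owned by this thread's row group. */
        for (uint32_t r = 0; r < rows_per_thread; ++r) {
            row_hits[row_tid * rows_per_thread + r]++;
        }
        /* Columns are strided by cols_per_iter, as in the Sf/Pf loops; count once per
         * (row group, column) by looking only at the d_tid == 0 lane of each column. */
        if (d_tid == 0) {
            for (uint32_t c = 0; c < cols_per_thread; ++c) {
                col_hits[c * cols_per_iter + col_tid]++;
            }
        }
    }

    /* Every row is visited by its whole row group; every column by one lane per row group. */
    for (int r = 0; r < Br; ++r) assert(row_hits[r] == WorkGroupSize / row_split);
    for (int c = 0; c < Bc; ++c) assert(col_hits[c] == row_split);
    printf("rows_per_thread=%d cols_per_iter=%d cols_per_thread=%d\n",
           rows_per_thread, cols_per_iter, cols_per_thread);
    return 0;
}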