From 530e5bb117cd2b27c8aebc051fe0aac92df2cb8c Mon Sep 17 00:00:00 2001 From: Progeny Alpha Date: Sat, 14 Mar 2026 22:32:46 -0400 Subject: [PATCH] vulkan: fuse w/k_gated broadcasts in chunked inter kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Load both s_w and s_kg before the first barrier instead of using separate barriers for each. Reduces per-token barriers from 3 to 2, eliminating 64 barriers per chunk. GDN per-op: 6818 → 5205 µs (-23.6%). 16/16 tests pass. --- .../gated_delta_net_chunk_inter.comp | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp index 8ebed83b20..2ed8260c76 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -74,6 +74,12 @@ void main() { for (uint t = 0; t < chunk_len; t++) { s_w[col] = w_in[wu_base + t * S_V + col]; + + const float g_cumsum_t = gcum_in[gcum_base + t]; + const float decay_factor = exp(g_total - g_cumsum_t); + const uint t_global = chunk_start + t; + const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1; + s_kg[col] = k_in[k_off + col] * decay_factor; barrier(); float ws = 0.0; @@ -87,15 +93,6 @@ void main() { float vnew = u_in[wu_base + t * S_V + col] - ws; vnew_out[wu_base + t * S_V + col] = vnew; - // K_gated[t] = k[t] * exp(g_total - g_cumsum[t]) - float g_cumsum_t = gcum_in[gcum_base + t]; - float decay_factor = exp(g_total - g_cumsum_t); - - const uint t_global = chunk_start + t; - const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1; - s_kg[col] = k_in[k_off + col] * decay_factor; - barrier(); - [[unroll]] for (uint i = 0; i < S_V; i++) { delta[i] += s_kg[i] * vnew; }