From 88396c39232c2c93d36fccd232eda42d312fca77 Mon Sep 17 00:00:00 2001 From: Progeny Alpha Date: Sat, 14 Mar 2026 22:48:11 -0400 Subject: [PATCH] vulkan: optimize chunked intra kernel barrier and bank conflicts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove unnecessary barrier after A-matrix dot product writes. Each thread writes only to its own row; s_A isn't read cross-thread until forward substitution. Cuts A-matrix barriers from 128 to 65 (one per broadcast + one before forward sub). Pad s_A stride from 64 to 65 to eliminate bank conflicts in the W/U accumulation phase where all active threads read A(tid, j) with the same j value. GDN per-op: 5205 → 5136 µs. Combined with inter fusion: 6818 → 5136 µs (-24.7%). 16/16 tests pass. --- .../vulkan-shaders/gated_delta_net_chunk_intra.comp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp index 881fc98c22..eff8605e37 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp @@ -30,13 +30,14 @@ layout(binding = 5) writeonly buffer UBuf { float u_out[]; }; layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; }; layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum -shared float s_A[CHUNK_SIZE * CHUNK_SIZE]; +const uint A_STRIDE = CHUNK_SIZE + 1; +shared float s_A[CHUNK_SIZE * A_STRIDE]; shared float s_decay[CHUNK_SIZE]; shared float s_beta[CHUNK_SIZE]; shared float s_k_broadcast[S_V]; shared float s_v_broadcast[S_V]; -#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)] +#define A(i,j) s_A[(i) * A_STRIDE + (j)] void main() { const uint chunk_head = gl_WorkGroupID.x; @@ -115,8 +116,8 @@ void main() { float decay_factor = exp(s_decay[tid] - s_decay[j]); A(tid, j) = -s_beta[tid] * dot_kk * decay_factor; } - barrier(); } + barrier(); // Forward substitution: T = (I + A)^{-1}, in-place for (uint i = 1; i < chunk_len; i++) {