# Patch summary (descriptive preamble above the "diff --git" header; not part of any hunk).
#
# Hunk 1: introduces A_STRIDE = CHUNK_SIZE + 1 and uses it both for the size of
#   the shared array s_A (CHUNK_SIZE * A_STRIDE floats) and as the row stride in
#   the A(i,j) accessor macro, so each row of A gains one unused padding element.
#   NOTE(review): presumably this padding is meant to avoid shared-memory bank
#   conflicts on column-wise access -- the motivation is not stated in the patch; confirm.
#
# Hunk 2: moves one barrier() call from just before a closing brace to just after
#   it, so it now sits directly ahead of the "Forward substitution" loop.
#   NOTE(review): the construct that this brace closes is outside the hunk context;
#   verify against the full shader that every invocation still reaches the barrier
#   and that no synchronization the enclosed code relied on is lost.
#
# NOTE(review): leading indentation and blank context lines below were reconstructed
#   from a whitespace-collapsed copy of this patch (line counts match the hunk
#   headers); confirm against the original before applying.
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp
index 881fc98c22..eff8605e37 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp
@@ -30,13 +30,14 @@ layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
 layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
 layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum
 
-shared float s_A[CHUNK_SIZE * CHUNK_SIZE];
+const uint A_STRIDE = CHUNK_SIZE + 1;
+shared float s_A[CHUNK_SIZE * A_STRIDE];
 shared float s_decay[CHUNK_SIZE];
 shared float s_beta[CHUNK_SIZE];
 shared float s_k_broadcast[S_V];
 shared float s_v_broadcast[S_V];
 
-#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)]
+#define A(i,j) s_A[(i) * A_STRIDE + (j)]
 
 void main() {
     const uint chunk_head = gl_WorkGroupID.x;
@@ -115,8 +116,8 @@ void main() {
             float decay_factor = exp(s_decay[tid] - s_decay[j]);
             A(tid, j) = -s_beta[tid] * dot_kk * decay_factor;
         }
-        barrier();
     }
+    barrier();
 
     // Forward substitution: T = (I + A)^{-1}, in-place
     for (uint i = 1; i < chunk_len; i++) {