diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp index 2ed8260c76..ed41987c2f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -105,7 +105,6 @@ void main() { } } - // Write final state to dst at s_off [[unroll]] for (uint i = 0; i < S_V; i++) { final_out[s_off + state_base + col * S_V + i] = state[i]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp index eff8605e37..5afa5af19c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp @@ -2,8 +2,6 @@ #extension GL_EXT_control_flow_attributes : require -// Intra-chunk WY decomposition for chunked gated delta net - layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint CHUNK_SIZE = 64; @@ -28,7 +26,7 @@ layout(binding = 3) readonly buffer BetaBuf { float beta_in[]; }; layout(binding = 4) writeonly buffer WBuf { float w_out[]; }; layout(binding = 5) writeonly buffer UBuf { float u_out[]; }; layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; }; -layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum +layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; const uint A_STRIDE = CHUNK_SIZE + 1; shared float s_A[CHUNK_SIZE * A_STRIDE]; @@ -54,7 +52,6 @@ void main() { const uint global_t = chunk_start + tid; const bool valid = tid < chunk_len; - // Load beta and gate if (valid) { const uint gb_off = seq_id * sb3 + global_t * sb2 + head_id * sb1; s_beta[tid] = beta_in[gb_off]; @@ -89,7 +86,7 @@ void main() { } } - for (uint j = 0; j < CHUNK_SIZE; j++) { + [[unroll]] for (uint j = 0; j < CHUNK_SIZE; j++) { A(tid, j) = 0.0; } barrier(); @@ -119,7 +116,6 @@ void main() { } barrier(); - // Forward substitution: T = (I + A)^{-1}, in-place for (uint i = 1; i < chunk_len; i++) { if (tid < i) { float sum = 0.0; @@ -136,14 +132,13 @@ void main() { } barrier(); - // W and U via tiled broadcast accumulation const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V; const uint TILE_D = 32; for (uint d_start = 0; d_start < S_V; d_start += TILE_D) { float my_w[TILE_D]; float my_u[TILE_D]; - for (uint d = 0; d < TILE_D; d++) { + [[unroll]] for (uint d = 0; d < TILE_D; d++) { my_w[d] = 0.0; my_u[d] = 0.0; } @@ -154,7 +149,6 @@ void main() { const uint vj_off = seq_id * sv3 + j_global * sv2 + head_id * sv1; float eg = exp(s_decay[j]); - // Broadcast tile of k[j] and v[j] for (uint d = tid; d < S_V; d += CHUNK_SIZE) { if (d >= d_start && d < d_start + TILE_D) { s_k_broadcast[d] = k_in[kj_off + d] * eg; @@ -173,17 +167,15 @@ void main() { barrier(); } - // Write tile to global memory if (valid) { - for (uint d = 0; d < TILE_D; d++) { + [[unroll]] for (uint d = 0; d < TILE_D; d++) { w_out[out_base + tid * S_V + d_start + d] = my_w[d]; u_out[out_base + tid * S_V + d_start + d] = my_u[d]; } } } - // Output total chunk decay - if (tid == 0) { + if (tid == 0 && chunk_len > 0) { const uint decay_idx = (seq_id * n_chunks + chunk_id) * H + head_id; decay_out[decay_idx] = s_decay[chunk_len - 1]; }