From e22c2b2c851addfe653c9766841178efc9c02731 Mon Sep 17 00:00:00 2001 From: Progeny Alpha Date: Sat, 14 Mar 2026 01:13:42 -0400 Subject: [PATCH] vulkan: clean up chunked GDN shaders for PR review Remove verbose algorithm comments, section dividers, stale inline constant annotations, and unused extensions. Match llama.cpp codebase style (minimal comments, no section decorators). No functional changes. 16/16 tests pass. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 -- .../gated_delta_net_chunk_inter.comp | 11 ---- .../gated_delta_net_chunk_intra.comp | 39 +------------ .../gated_delta_net_chunk_output_cm1.comp | 57 ++++--------------- 4 files changed, 15 insertions(+), 98 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d1c470b1c1..207d625b10 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10519,8 +10519,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s n_chunks, s_off }; - // Dispatch 1: Intra-chunk (parallel across chunks) - // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, {src_buf[1], src_buf[2], src_buf[3], src_buf[4], scratch_w, scratch_u, scratch_dec, scratch_gcum}, @@ -10528,8 +10526,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_sync_buffers(ctx, subctx); - // Dispatch 2: Inter-chunk state propagation (sequential across chunks) - // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst) ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, src_buf[5], scratch_h, scratch_vnew, dst_buf}, @@ -10537,8 +10533,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_sync_buffers(ctx, subctx); - // Dispatch 3: Output (parallel across chunks) - // Bindings: Q, K, H, VNew, GCum, Dst ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, pc, { n_chunks * H, n_seqs, 1u }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp index 0aa54e718f..8ebed83b20 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -3,17 +3,6 @@ #extension GL_EXT_control_flow_attributes : require // Inter-chunk state propagation for chunked gated delta net -// -// Sequential across chunks, parallel across state columns. -// For each chunk c: -// 1. Store state snapshot h[c] for output kernel -// 2. v_corrected = U - W @ S (C x d) -// 3. S_next = exp(g_total) * S + K_gated^T @ v_corrected (d x d) -// -// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t]) -// -// Grid: (H, n_seqs, 1) -// Workgroup: S_V threads (one per state column) layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint CHUNK_SIZE = 64; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp index 8e1820ca85..881fc98c22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp @@ -2,21 +2,7 @@ #extension GL_EXT_control_flow_attributes : require -// Intra-chunk kernel for chunked gated delta net (non-KDA scalar gate) -// -// For each chunk of C=64 tokens, computes W and U using the WY representation. -// Uses a single A matrix with gates, compensates W by multiplying k by exp(g). -// -// Algorithm (FLA-equivalent): -// 1. g_cumsum[j] = cumsum(g[0..j]) within chunk (log-space) -// 2. A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j -// 3. T = (I + A)^{-1} via forward substitution (row-by-row parallel) -// 4. U[i] = sum_j T[i][j] * beta[j] * v[j] -// 5. W[i] = sum_j T[i][j] * beta[j] * exp(g_cumsum[j]) * k[j] -// 6. Output g_cumsum[C-1] as total chunk decay -// -// Grid: (n_chunks * H, n_seqs, 1) -// Workgroup: CHUNK_SIZE threads (one per token in chunk) +// Intra-chunk WY decomposition for chunked gated delta net layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint CHUNK_SIZE = 64; @@ -78,7 +64,6 @@ void main() { } barrier(); - // Step 1: Prefix sum of log-gates if (tid == 0) { for (uint i = 1; i < chunk_len; i++) { s_decay[i] += s_decay[i - 1]; @@ -86,13 +71,11 @@ void main() { } barrier(); - // Output per-token g_cumsum for inter-chunk and output kernels if (valid) { const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE; gcum_out[gcum_base + tid] = s_decay[tid]; } - // Load my k vector into registers float my_k[S_V]; if (valid) { const uint k_off = iq3 * sq3 + global_t * sq2 + iq1 * sq1; @@ -105,17 +88,12 @@ void main() { } } - // Step 2: Build A matrix using shared memory broadcast for k[j] - // A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j - // Initialize to zero for (uint j = 0; j < CHUNK_SIZE; j++) { A(tid, j) = 0.0; } barrier(); - // For each column j, broadcast k[j] via shared memory, all threads compute their row for (uint j = 0; j < chunk_len; j++) { - // Broadcast k[j] — need multiple passes when S_V > CHUNK_SIZE { const uint j_global = chunk_start + j; const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1; @@ -140,12 +118,8 @@ void main() { barrier(); } - // Step 3: Forward substitution T = (I + A)^{-1} - // Process row by row. For row i, all threads j < i compute in parallel: - // T[i][j] += sum_m T[i][m] * T[m][j] for m in [j..i-1] - // The A matrix is modified in-place to become T. + // Forward substitution: T = (I + A)^{-1}, in-place for (uint i = 1; i < chunk_len; i++) { - // Each thread with tid < i computes T[i][tid] if (tid < i) { float sum = 0.0; for (uint m = tid; m < i; m++) { @@ -156,19 +130,12 @@ void main() { barrier(); } - // Add identity if (valid) { A(tid, tid) = 1.0; } barrier(); - // Step 4+5: Compute W and U via shared memory broadcast + register accumulation - // U[tid][d] = sum_j T[tid][j] * beta[j] * v[j][d] - // W[tid][d] = sum_j T[tid][j] * beta[j] * exp(g_cumsum[j]) * k[j][d] - // - // For each j, broadcast k[j]*exp(g[j]) and v[j] via shared memory. - // Accumulate in d-tiles of 32 to limit register pressure. - + // W and U via tiled broadcast accumulation const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V; const uint TILE_D = 32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp index 533f5f4730..31fba4522a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp @@ -1,27 +1,13 @@ #version 450 #extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable -#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #include "types.glsl" -// Coopmat output kernel for chunked gated delta net -// -// Phase 1: A = Q @ K^T (coopmat, f16→f32) -// Phase 2: Decay mask → sh_adecay (f16) + vnew → sh_kv (f16, pre-scaled) -// Pass 1: O_inter = Q @ S → dst (scalar, 128 threads) -// Pass 2: O_intra = A_decayed @ vnew → dst (coopmat GEMM, full chunks) -// Partial last chunk: scalar fallback. 3 barriers total. -// -// Grid: (n_chunks * H, n_seqs, 1) -// Workgroup: WG_SIZE threads = 4 subgroups - layout(constant_id = 0) const uint WG_SIZE = 256; layout(constant_id = 1) const uint CHUNK_SIZE = 64; layout(constant_id = 2) const uint S_V = 128; @@ -51,16 +37,16 @@ const uint TM = 16; const uint TN = 16; const uint TK = 16; -const uint C_TILES = CHUNK_SIZE / TM; // 4 -const uint D_TILES = S_V / TN; // 8 +const uint C_TILES = CHUNK_SIZE / TM; +const uint D_TILES = S_V / TN; // Shared memory strides in f16vec4 units, padded for bank conflicts -const uint QK_STRIDE = S_V / 4 + 2; // 34 -const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; // 18 +const uint QK_STRIDE = S_V / 4 + 2; +const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; -shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; // Q (f16) for coopmat Phase 1 -shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; // K (f16) for coopmat Phase 1 -shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; // attention matrix (f32) +shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; +shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; +shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE]; shared float sh_gcum[CHUNK_SIZE]; @@ -80,10 +66,6 @@ void main() { const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start); const float scale = 1.0 / sqrt(float(S_V)); - // ================================================================ - // Load Q, K, gcum to shared memory - // ================================================================ - if (tid < CHUNK_SIZE) { const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE; sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0; @@ -105,10 +87,7 @@ void main() { barrier(); - // ================================================================ - // Phase 1: A = Q @ K^T [C×D] × [D×C] → [C×C] (coopmat) - // ================================================================ - + // A = Q @ K^T (coopmat) coopmat A_acc[C_TILES]; [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { A_acc[tj] = coopmat(0.0); @@ -133,7 +112,6 @@ void main() { } } - // Store A to sh_attn as f32 [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { coopMatStore(A_acc[tj], sh_attn, sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4), @@ -142,16 +120,10 @@ void main() { barrier(); - // ================================================================ - // Phase 2: Decay mask + vnew load (all 256 threads) - // Pass 1: Inter-chunk Q@S → dst (128 active threads) - // No shared-memory write conflicts — runs without intermediate barrier. - // ================================================================ - const uint wu_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V; const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; - // Phase 2a: A_decayed = causal_decay_mask(A) → sh_adecay (f16) + // Causal decay mask (f16) for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) { const uint t = idx / (CHUNK_SIZE / 4); const uint j4 = idx % (CHUNK_SIZE / 4); @@ -170,7 +142,7 @@ void main() { sh_adecay[t * ATTN_V4_STRIDE + j4] = val; } - // Phase 2b: vnew → sh_kv (f16, pre-scaled by 1/√S_V) + // vnew to f16, pre-scaled for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) { const uint row = idx / (S_V / 4); const uint col4 = idx % (S_V / 4); @@ -186,7 +158,7 @@ void main() { sh_kv[row * QK_STRIDE + col4] = val; } - // Pass 1: Inter-chunk (128 active threads, write directly to dst) + // O_inter = Q @ state { const uint col = tid; const bool col_active = (col < S_V); @@ -216,12 +188,7 @@ void main() { barrier(); - // ================================================================ - // Pass 2: Intra-chunk A_decayed[C×C] @ vnew[C×S_V] → [C×S_V] - // Full chunks: coopmat GEMM (accumulates onto inter results in dst). - // Partial (last) chunk: scalar fallback. - // ================================================================ - + // O_intra = A_decayed @ vnew (coopmat GEMM, scalar fallback for partial chunks) if (chunk_len == CHUNK_SIZE) { coopmat A_mat; coopmat V_mat;