diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d1c470b1c1..207d625b10 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10519,8 +10519,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s n_chunks, s_off }; - // Dispatch 1: Intra-chunk (parallel across chunks) - // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, {src_buf[1], src_buf[2], src_buf[3], src_buf[4], scratch_w, scratch_u, scratch_dec, scratch_gcum}, @@ -10528,8 +10526,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_sync_buffers(ctx, subctx); - // Dispatch 2: Inter-chunk state propagation (sequential across chunks) - // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst) ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, src_buf[5], scratch_h, scratch_vnew, dst_buf}, @@ -10537,8 +10533,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_sync_buffers(ctx, subctx); - // Dispatch 3: Output (parallel across chunks) - // Bindings: Q, K, H, VNew, GCum, Dst ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, pc, { n_chunks * H, n_seqs, 1u }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp index 0aa54e718f..8ebed83b20 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -3,17 +3,6 @@ #extension GL_EXT_control_flow_attributes : require // Inter-chunk state propagation for chunked gated delta net -// -// Sequential across chunks, parallel across state columns. -// For each chunk c: -// 1. Store state snapshot h[c] for output kernel -// 2. v_corrected = U - W @ S (C x d) -// 3. S_next = exp(g_total) * S + K_gated^T @ v_corrected (d x d) -// -// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t]) -// -// Grid: (H, n_seqs, 1) -// Workgroup: S_V threads (one per state column) layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint CHUNK_SIZE = 64; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp index 8e1820ca85..881fc98c22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp @@ -2,21 +2,7 @@ #extension GL_EXT_control_flow_attributes : require -// Intra-chunk kernel for chunked gated delta net (non-KDA scalar gate) -// -// For each chunk of C=64 tokens, computes W and U using the WY representation. -// Uses a single A matrix with gates, compensates W by multiplying k by exp(g). -// -// Algorithm (FLA-equivalent): -// 1. g_cumsum[j] = cumsum(g[0..j]) within chunk (log-space) -// 2. A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j -// 3. T = (I + A)^{-1} via forward substitution (row-by-row parallel) -// 4. U[i] = sum_j T[i][j] * beta[j] * v[j] -// 5. W[i] = sum_j T[i][j] * beta[j] * exp(g_cumsum[j]) * k[j] -// 6. Output g_cumsum[C-1] as total chunk decay -// -// Grid: (n_chunks * H, n_seqs, 1) -// Workgroup: CHUNK_SIZE threads (one per token in chunk) +// Intra-chunk WY decomposition for chunked gated delta net layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint CHUNK_SIZE = 64; @@ -78,7 +64,6 @@ void main() { } barrier(); - // Step 1: Prefix sum of log-gates if (tid == 0) { for (uint i = 1; i < chunk_len; i++) { s_decay[i] += s_decay[i - 1]; @@ -86,13 +71,11 @@ void main() { } barrier(); - // Output per-token g_cumsum for inter-chunk and output kernels if (valid) { const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE; gcum_out[gcum_base + tid] = s_decay[tid]; } - // Load my k vector into registers float my_k[S_V]; if (valid) { const uint k_off = iq3 * sq3 + global_t * sq2 + iq1 * sq1; @@ -105,17 +88,12 @@ void main() { } } - // Step 2: Build A matrix using shared memory broadcast for k[j] - // A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j - // Initialize to zero for (uint j = 0; j < CHUNK_SIZE; j++) { A(tid, j) = 0.0; } barrier(); - // For each column j, broadcast k[j] via shared memory, all threads compute their row for (uint j = 0; j < chunk_len; j++) { - // Broadcast k[j] — need multiple passes when S_V > CHUNK_SIZE { const uint j_global = chunk_start + j; const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1; @@ -140,12 +118,8 @@ void main() { barrier(); } - // Step 3: Forward substitution T = (I + A)^{-1} - // Process row by row. For row i, all threads j < i compute in parallel: - // T[i][j] += sum_m T[i][m] * T[m][j] for m in [j..i-1] - // The A matrix is modified in-place to become T. + // Forward substitution: T = (I + A)^{-1}, in-place for (uint i = 1; i < chunk_len; i++) { - // Each thread with tid < i computes T[i][tid] if (tid < i) { float sum = 0.0; for (uint m = tid; m < i; m++) { @@ -156,19 +130,12 @@ void main() { barrier(); } - // Add identity if (valid) { A(tid, tid) = 1.0; } barrier(); - // Step 4+5: Compute W and U via shared memory broadcast + register accumulation - // U[tid][d] = sum_j T[tid][j] * beta[j] * v[j][d] - // W[tid][d] = sum_j T[tid][j] * beta[j] * exp(g_cumsum[j]) * k[j][d] - // - // For each j, broadcast k[j]*exp(g[j]) and v[j] via shared memory. - // Accumulate in d-tiles of 32 to limit register pressure. - + // W and U via tiled broadcast accumulation const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V; const uint TILE_D = 32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp index 533f5f4730..31fba4522a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp @@ -1,27 +1,13 @@ #version 450 #extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable -#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #include "types.glsl" -// Coopmat output kernel for chunked gated delta net -// -// Phase 1: A = Q @ K^T (coopmat, f16→f32) -// Phase 2: Decay mask → sh_adecay (f16) + vnew → sh_kv (f16, pre-scaled) -// Pass 1: O_inter = Q @ S → dst (scalar, 128 threads) -// Pass 2: O_intra = A_decayed @ vnew → dst (coopmat GEMM, full chunks) -// Partial last chunk: scalar fallback. 3 barriers total. -// -// Grid: (n_chunks * H, n_seqs, 1) -// Workgroup: WG_SIZE threads = 4 subgroups - layout(constant_id = 0) const uint WG_SIZE = 256; layout(constant_id = 1) const uint CHUNK_SIZE = 64; layout(constant_id = 2) const uint S_V = 128; @@ -51,16 +37,16 @@ const uint TM = 16; const uint TN = 16; const uint TK = 16; -const uint C_TILES = CHUNK_SIZE / TM; // 4 -const uint D_TILES = S_V / TN; // 8 +const uint C_TILES = CHUNK_SIZE / TM; +const uint D_TILES = S_V / TN; // Shared memory strides in f16vec4 units, padded for bank conflicts -const uint QK_STRIDE = S_V / 4 + 2; // 34 -const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; // 18 +const uint QK_STRIDE = S_V / 4 + 2; +const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; -shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; // Q (f16) for coopmat Phase 1 -shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; // K (f16) for coopmat Phase 1 -shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; // attention matrix (f32) +shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; +shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; +shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE]; shared float sh_gcum[CHUNK_SIZE]; @@ -80,10 +66,6 @@ void main() { const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start); const float scale = 1.0 / sqrt(float(S_V)); - // ================================================================ - // Load Q, K, gcum to shared memory - // ================================================================ - if (tid < CHUNK_SIZE) { const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE; sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0; @@ -105,10 +87,7 @@ void main() { barrier(); - // ================================================================ - // Phase 1: A = Q @ K^T [C×D] × [D×C] → [C×C] (coopmat) - // ================================================================ - + // A = Q @ K^T (coopmat) coopmat A_acc[C_TILES]; [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { A_acc[tj] = coopmat(0.0); @@ -133,7 +112,6 @@ void main() { } } - // Store A to sh_attn as f32 [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { coopMatStore(A_acc[tj], sh_attn, sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4), @@ -142,16 +120,10 @@ void main() { barrier(); - // ================================================================ - // Phase 2: Decay mask + vnew load (all 256 threads) - // Pass 1: Inter-chunk Q@S → dst (128 active threads) - // No shared-memory write conflicts — runs without intermediate barrier. - // ================================================================ - const uint wu_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V; const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; - // Phase 2a: A_decayed = causal_decay_mask(A) → sh_adecay (f16) + // Causal decay mask (f16) for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) { const uint t = idx / (CHUNK_SIZE / 4); const uint j4 = idx % (CHUNK_SIZE / 4); @@ -170,7 +142,7 @@ void main() { sh_adecay[t * ATTN_V4_STRIDE + j4] = val; } - // Phase 2b: vnew → sh_kv (f16, pre-scaled by 1/√S_V) + // vnew to f16, pre-scaled for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) { const uint row = idx / (S_V / 4); const uint col4 = idx % (S_V / 4); @@ -186,7 +158,7 @@ void main() { sh_kv[row * QK_STRIDE + col4] = val; } - // Pass 1: Inter-chunk (128 active threads, write directly to dst) + // O_inter = Q @ state { const uint col = tid; const bool col_active = (col < S_V); @@ -216,12 +188,7 @@ void main() { barrier(); - // ================================================================ - // Pass 2: Intra-chunk A_decayed[C×C] @ vnew[C×S_V] → [C×S_V] - // Full chunks: coopmat GEMM (accumulates onto inter results in dst). - // Partial (last) chunk: scalar fallback. - // ================================================================ - + // O_intra = A_decayed @ vnew (coopmat GEMM, scalar fallback for partial chunks) if (chunk_len == CHUNK_SIZE) { coopmat A_mat; coopmat V_mat;