From 949a7e86d34f6a5688ae40304f2a7056618ffcea Mon Sep 17 00:00:00 2001 From: Progeny Alpha Date: Tue, 10 Mar 2026 22:51:11 -0400 Subject: [PATCH] vulkan: add chunked parallel kernel infrastructure for GATED_DELTA_NET Three-dispatch chunked pipeline for prompt processing acceleration: intra-chunk WY decomposition, inter-chunk state propagation, output combination. Currently disabled (threshold=UINT32_MAX). Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 133 ++++++++++- .../gated_delta_net_chunk_inter.comp | 126 ++++++++++ .../gated_delta_net_chunk_intra.comp | 222 ++++++++++++++++++ .../gated_delta_net_chunk_output.comp | 124 ++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 + tests/test-backend-ops.cpp | 3 + 6 files changed, 600 insertions(+), 11 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_intra.comp create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3c81805b84..754fd7900c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -827,6 +827,9 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv7_f32; // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128 vk_pipeline pipeline_gated_delta_net[3][2]; + vk_pipeline pipeline_gated_delta_net_chunk_intra; + vk_pipeline pipeline_gated_delta_net_chunk_inter; + vk_pipeline pipeline_gated_delta_net_chunk_output; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -1468,6 +1471,18 @@ struct vk_op_gated_delta_net_push_constants { float scale; }; +struct vk_op_gated_delta_net_chunk_push_constants { + uint32_t H; + uint32_t n_tokens; + uint32_t n_seqs; + uint32_t sq1, sq2, sq3; + uint32_t sv1, sv2, sv3; + uint32_t sb1, sb2, 
sb3; + uint32_t neq1, rq3; + uint32_t n_chunks; + uint32_t s_off; +}; + struct vk_op_ssm_scan_push_constants { uint32_t nb02, nb03, nb12, nb13; uint32_t nb21, nb22, nb31; @@ -4599,6 +4614,16 @@ static void ggml_vk_load_shaders(vk_device& device) { } } + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_intra, "gated_delta_net_chunk_intra_f32_d128", + gated_delta_net_chunk_intra_f32_len, gated_delta_net_chunk_intra_f32_data, "main", + 8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_inter, "gated_delta_net_chunk_inter_f32_d128", + gated_delta_net_chunk_inter_f32_len, gated_delta_net_chunk_inter_f32_data, "main", + 9, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output, "gated_delta_net_chunk_output_f32_d128", + gated_delta_net_chunk_output_f32_len, gated_delta_net_chunk_output_f32_data, "main", + 6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); @@ -10373,9 +10398,13 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static constexpr uint32_t GDN_CHUNK_SIZE = 64; +static constexpr uint32_t GDN_CHUNK_THRESHOLD = UINT32_MAX; // Disabled + static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, 
vk_context& subctx, ggml_tensor * dst) { const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; const ggml_tensor * src_beta = dst->src[4]; GGML_ASSERT(dst->buffer != nullptr); @@ -10386,11 +10415,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_seqs = (uint32_t)src_v->ne[3]; const uint32_t s_off = S_v * H * n_tokens * n_seqs; - - vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); - GGML_ASSERT(pipeline != nullptr); - - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + const bool kda = (src_g->ne[0] == (int64_t)S_v); + const bool use_chunked = !kda && S_v == 128 && n_tokens > GDN_CHUNK_THRESHOLD; vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); vk_subbuffer src_buf[6] = {}; @@ -10411,19 +10437,104 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t neq1 = (uint32_t)src_q->ne[1]; const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]); - const float scale = 1.0f / sqrtf((float)S_v); - const vk_op_gated_delta_net_push_constants pc = { - H, n_tokens, n_seqs, s_off, + if (!use_chunked) { + // Autoregressive path (optimal for TG / small n_tokens) + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const float scale = 1.0f / sqrtf((float)S_v); + const vk_op_gated_delta_net_push_constants pc = { + H, n_tokens, n_seqs, s_off, + sq1, sq2, sq3, + sv1, sv2, sv3, + sb1, sb2, sb3, + neq1, rq3, + scale + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, { H, n_seqs, 1u }); + return; + } + + // Chunked parallel path (PP acceleration) + const uint32_t n_chunks = (n_tokens + 
GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE; + + vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; + vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; + vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output; + + ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1); + + // Scratch buffer layout within prealloc_split_k + const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float); + const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float); + const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float); + const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float); + + const size_t w_off = 0; + const size_t u_off = wu_size; + const size_t vn_off = 2 * wu_size; + const size_t dec_off = 3 * wu_size; + const size_t gcum_off = dec_off + d_size; + const size_t h_off = gcum_off + g_size; + const size_t total_scratch = h_off + h_size; + + if (ctx->prealloc_size_split_k < total_scratch) { + ctx->prealloc_size_split_k = total_scratch; + ggml_vk_preallocate_buffers(ctx, subctx); + } + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; + vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; + vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size }; + vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; + vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; + vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size }; + + const vk_op_gated_delta_net_chunk_push_constants pc = { + H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + n_chunks, s_off }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - 
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + // Dispatch 1: Intra-chunk (parallel across chunks) + // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out + ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, + {src_buf[1], src_buf[2], src_buf[3], src_buf[4], + scratch_w, scratch_u, scratch_dec, scratch_gcum}, + pc, { n_chunks * H, n_seqs, 1u }); + + ggml_vk_sync_buffers(ctx, subctx); + + // Dispatch 2: Inter-chunk state propagation (sequential across chunks) + // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst) + ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, + {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, + src_buf[5], scratch_h, scratch_vnew, dst_buf}, pc, { H, n_seqs, 1u }); + + ggml_vk_sync_buffers(ctx, subctx); + + // Dispatch 3: Output (parallel across chunks) + // Bindings: Q, K, H, VNew, GCum, Dst + ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, + {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, + pc, { n_chunks * H, n_seqs, 1u }); + + ctx->prealloc_split_k_need_sync = true; } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp new file mode 100644 index 0000000000..11cd0e18a8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -0,0 +1,126 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +// Inter-chunk state propagation for chunked gated delta net +// +// Sequential across chunks, parallel across state columns. +// For each chunk c: +// 1. Store state snapshot h[c] for output kernel +// 2. v_corrected = U - W @ S (C x d) +// 3. 
#version 450

#extension GL_EXT_control_flow_attributes : require

// Inter-chunk state propagation for chunked gated delta net
//
// Sequential across chunks, parallel across state columns.
// For each chunk c:
//   1. Store state snapshot h[c] for output kernel
//   2. v_corrected = U - W @ S   (C x d)
//   3. S_next = exp(g_total) * S + K_gated^T @ v_corrected   (d x d)
//
// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
//
// Grid: (H, n_seqs, 1)
// Workgroup: S_V threads (one per state column)

layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;

// local_size_x is specialized to S_V (constant_id 0), so every invocation
// owns exactly one state column and the `col >= S_V` guard below is a
// defensive no-op.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint H;          // number of heads
    uint n_tokens;   // tokens per sequence
    uint n_seqs;     // number of sequences
    uint sq1, sq2, sq3;  // q/k element strides (head, token, seq) — assumed set by host; confirm against dispatch code
    uint sv1, sv2, sv3;  // v element strides (head, token, seq)
    uint sb1, sb2, sb3;  // beta/gate element strides (head, token, seq)
    uint neq1, rq3;      // q head count and seq broadcast ratio (GQA-style mapping)
    uint n_chunks;       // ceil(n_tokens / CHUNK_SIZE)
    uint s_off;          // element offset of the recurrent-state region inside dst
};

layout(binding = 0) readonly buffer KBuf { float k_in[]; };
layout(binding = 1) readonly buffer WBuf { float w_in[]; };        // W from intra kernel
layout(binding = 2) readonly buffer UBuf { float u_in[]; };        // U from intra kernel
layout(binding = 3) readonly buffer DecayBuf { float decay_in[]; };// per-chunk total log-decay
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; };  // per-token in-chunk log-decay cumsum
layout(binding = 5) readonly buffer StateBuf { float state_in[]; };// initial recurrent state
layout(binding = 6) writeonly buffer HBuf { float h_out[]; };      // per-chunk state snapshots
layout(binding = 7) writeonly buffer VNewBuf { float vnew_out[]; };// corrected v for output kernel
layout(binding = 8) buffer FinalBuf { float final_out[]; };        // dst: final state written at s_off

// Per-token broadcast staging: one W row / one gated-k row shared by the workgroup.
shared float s_w[S_V];
shared float s_kg[S_V];

void main() {
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    const uint col = gl_LocalInvocationID.x;

    if (col >= S_V) return;

    // Map this (head, seq) onto the q/k tensor (k may have fewer heads/seqs than v).
    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;

    // Each thread carries one state column (S_V rows) in registers across all chunks.
    float state[S_V];
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        state[i] = state_in[state_base + i * S_V + col];
    }

    // Chunks must be processed in order: each chunk's update depends on the
    // state left by the previous one.
    for (uint c = 0; c < n_chunks; c++) {
        const uint chunk_start = c * CHUNK_SIZE;
        const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);

        // 1. Snapshot the incoming state for the output kernel (state *before*
        //    this chunk's update).
        const uint h_base = ((seq_id * n_chunks + c) * H + head_id) * state_size;
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            h_out[h_base + i * S_V + col] = state[i];
        }

        const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V;
        const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE;

        const uint decay_idx = (seq_id * n_chunks + c) * H + head_id;
        const float g_total = decay_in[decay_idx];

        // Accumulator for K_gated^T @ v_corrected, one column per thread.
        float delta[S_V];
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            delta[i] = 0.0;
        }

        for (uint t = 0; t < chunk_len; t++) {
            // Stage W[t] in shared memory so every thread can take the full dot
            // product against its private state column.
            s_w[col] = w_in[wu_base + t * S_V + col];
            barrier();

            float ws = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
                ws += dot(
                    vec4(s_w[i], s_w[i+1], s_w[i+2], s_w[i+3]),
                    vec4(state[i], state[i+1], state[i+2], state[i+3])
                );
            }

            // 2. v_corrected[t][col] = U[t][col] - (W[t] . state_col)
            float vnew = u_in[wu_base + t * S_V + col] - ws;
            vnew_out[wu_base + t * S_V + col] = vnew;

            // K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
            float g_cumsum_t = gcum_in[gcum_base + t];
            float decay_factor = exp(g_total - g_cumsum_t);

            const uint t_global = chunk_start + t;
            const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
            s_kg[col] = k_in[k_off + col] * decay_factor;
            barrier();

            // 3. Rank-1 accumulation: delta[i][col] += k_gated[i] * vnew[col]
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                delta[i] += s_kg[i] * vnew;
            }
            // Guard s_w/s_kg against being overwritten by the next iteration
            // while still being read.
            barrier();
        }

        // S_next = exp(g_total) * S + delta
        float total_decay = exp(g_total);
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            state[i] = total_decay * state[i] + delta[i];
        }
    }

    // Write final recurrent state into dst at s_off (same layout as state_in).
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        final_out[s_off + state_base + i * S_V + col] = state[i];
    }
}
#version 450

#extension GL_EXT_control_flow_attributes : require

// Intra-chunk kernel for chunked gated delta net (non-KDA scalar gate)
//
// For each chunk of C=64 tokens, computes W and U using the WY representation.
// Uses a single A matrix with gates, compensates W by multiplying k by exp(g).
//
// Algorithm (FLA-equivalent):
//   1. g_cumsum[j] = cumsum(g[0..j]) within chunk (log-space)
//   2. A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
//   3. T = (I + A)^{-1} via forward substitution (row-by-row parallel)
//   4. U[i] = sum_j T[i][j] * beta[j] * v[j]
//   5. W[i] = sum_j T[i][j] * beta[j] * exp(g_cumsum[j]) * k[j]
//   6. Output g_cumsum[C-1] as total chunk decay
//
// Grid: (n_chunks * H, n_seqs, 1)
// Workgroup: CHUNK_SIZE threads (one per token in chunk)

layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;

// local_size_x is specialized to CHUNK_SIZE (constant_id 1): one invocation per token.
layout(local_size_x_id = 1, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint H;          // number of heads
    uint n_tokens;   // tokens per sequence
    uint n_seqs;     // number of sequences
    uint sq1, sq2, sq3;  // q/k element strides (head, token, seq)
    uint sv1, sv2, sv3;  // v element strides (head, token, seq)
    uint sb1, sb2, sb3;  // beta/gate element strides (head, token, seq)
    uint neq1, rq3;      // q head count and seq broadcast ratio
    uint n_chunks;       // ceil(n_tokens / CHUNK_SIZE)
    uint s_off;          // unused here; kept so all three kernels share one PC layout
};

layout(binding = 0) readonly buffer KBuf { float k_in[]; };
layout(binding = 1) readonly buffer VBuf { float v_in[]; };
layout(binding = 2) readonly buffer GBuf { float g_in[]; };
layout(binding = 3) readonly buffer BetaBuf { float beta_in[]; };
layout(binding = 4) writeonly buffer WBuf { float w_out[]; };
layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum

shared float s_A[CHUNK_SIZE * CHUNK_SIZE];
shared float s_decay[CHUNK_SIZE];
shared float s_beta[CHUNK_SIZE];
shared float s_k_broadcast[S_V];
shared float s_v_broadcast[S_V];

#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)]

void main() {
    const uint chunk_head = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    const uint tid = gl_LocalInvocationID.x;

    const uint head_id = chunk_head % H;
    const uint chunk_id = chunk_head / H;
    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint chunk_start = chunk_id * CHUNK_SIZE;
    const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
    const uint global_t = chunk_start + tid;
    const bool valid = tid < chunk_len;

    // Load beta and gate; padding lanes contribute 0 so the shared arrays are
    // always fully initialized.
    if (valid) {
        const uint gb_off = seq_id * sb3 + global_t * sb2 + head_id * sb1;
        s_beta[tid] = beta_in[gb_off];
        s_decay[tid] = g_in[gb_off];
    } else {
        s_beta[tid] = 0.0;
        s_decay[tid] = 0.0;
    }
    barrier();

    // Step 1: Prefix sum of log-gates (serial on thread 0; CHUNK_SIZE is small)
    if (tid == 0) {
        for (uint i = 1; i < chunk_len; i++) {
            s_decay[i] += s_decay[i - 1];
        }
    }
    barrier();

    // Output per-token g_cumsum for inter-chunk and output kernels
    if (valid) {
        const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
        gcum_out[gcum_base + tid] = s_decay[tid];
    }

    // Load my k vector into registers
    float my_k[S_V];
    if (valid) {
        const uint k_off = iq3 * sq3 + global_t * sq2 + iq1 * sq1;
        [[unroll]] for (uint d = 0; d < S_V; d++) {
            my_k[d] = k_in[k_off + d];
        }
    } else {
        [[unroll]] for (uint d = 0; d < S_V; d++) {
            my_k[d] = 0.0;
        }
    }

    // Step 2: Build A matrix using shared memory broadcast for k[j]
    // A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
    // Initialize to zero
    for (uint j = 0; j < CHUNK_SIZE; j++) {
        A(tid, j) = 0.0;
    }
    barrier();

    // For each column j, broadcast k[j] via shared memory, all threads compute their row
    for (uint j = 0; j < chunk_len; j++) {
        // Broadcast k[j] — need multiple passes when S_V > CHUNK_SIZE
        {
            const uint j_global = chunk_start + j;
            const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
            for (uint d = tid; d < S_V; d += CHUNK_SIZE) {
                s_k_broadcast[d] = k_in[kj_off + d];
            }
        }
        barrier();

        if (valid && tid > j) {
            float dot_kk = 0.0;
            [[unroll]] for (uint d = 0; d < S_V; d += 4) {
                dot_kk += dot(
                    vec4(my_k[d], my_k[d+1], my_k[d+2], my_k[d+3]),
                    vec4(s_k_broadcast[d], s_k_broadcast[d+1],
                         s_k_broadcast[d+2], s_k_broadcast[d+3])
                );
            }
            float decay_factor = exp(s_decay[tid] - s_decay[j]);
            A(tid, j) = -s_beta[tid] * dot_kk * decay_factor;
        }
        barrier();
    }

    // Step 3: Forward substitution T = (I + A)^{-1}
    // Process row by row. For row i, all threads j < i compute in parallel:
    //   T[i][j] += sum_m T[i][m] * T[m][j]   for m in [j..i-1]
    // The A matrix is modified in-place to become T.
    //
    // BUGFIX: the recurrence needs the *pre-update* values of row i — thread j
    // reads A(i, m) for m in (j, i) while thread m writes A(i, m) in the same
    // round. Without a barrier between the read phase and the write phase this
    // is a shared-memory data race whenever the workgroup spans more than one
    // subgroup (e.g. 32-wide subgroups with CHUNK_SIZE = 64). Compute all sums
    // first, barrier, then commit the writes. Barriers stay in uniform control
    // flow (chunk_len is workgroup-uniform).
    for (uint i = 1; i < chunk_len; i++) {
        float sum = 0.0;
        if (tid < i) {
            // m = tid contributes 0 because the diagonal A(tid,tid) is still 0
            // (identity is added after this loop).
            for (uint m = tid; m < i; m++) {
                sum += A(i, m) * A(m, tid);
            }
        }
        barrier(); // all reads of row i complete before any write to it
        if (tid < i) {
            A(i, tid) += sum;
        }
        barrier(); // row i is final before it is read as T in the next round
    }

    // Add identity
    if (valid) {
        A(tid, tid) = 1.0;
    }
    barrier();

    // Step 4+5: Compute W and U via shared memory broadcast + register accumulation
    //   U[tid][d] = sum_j T[tid][j] * beta[j] * v[j][d]
    //   W[tid][d] = sum_j T[tid][j] * beta[j] * exp(g_cumsum[j]) * k[j][d]
    //
    // For each j, broadcast k[j]*exp(g[j]) and v[j] via shared memory.
    // Accumulate in d-tiles of 32 to limit register pressure.

    const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
    const uint TILE_D = 32;

    for (uint d_start = 0; d_start < S_V; d_start += TILE_D) {
        float my_w[TILE_D];
        float my_u[TILE_D];
        for (uint d = 0; d < TILE_D; d++) {
            my_w[d] = 0.0;
            my_u[d] = 0.0;
        }

        for (uint j = 0; j < chunk_len; j++) {
            const uint j_global = chunk_start + j;
            const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
            const uint vj_off = seq_id * sv3 + j_global * sv2 + head_id * sv1;
            float eg = exp(s_decay[j]);

            // Broadcast tile of k[j] and v[j]
            for (uint d = tid; d < S_V; d += CHUNK_SIZE) {
                if (d >= d_start && d < d_start + TILE_D) {
                    s_k_broadcast[d] = k_in[kj_off + d] * eg;
                    s_v_broadcast[d] = v_in[vj_off + d];
                }
            }
            barrier();

            // T is lower-triangular: only j <= tid contributes.
            if (valid && j <= tid) {
                float t_beta = A(tid, j) * s_beta[j];
                [[unroll]] for (uint d = 0; d < TILE_D; d++) {
                    my_w[d] += t_beta * s_k_broadcast[d_start + d];
                    my_u[d] += t_beta * s_v_broadcast[d_start + d];
                }
            }
            barrier();
        }

        // Write tile to global memory
        if (valid) {
            for (uint d = 0; d < TILE_D; d++) {
                w_out[out_base + tid * S_V + d_start + d] = my_w[d];
                u_out[out_base + tid * S_V + d_start + d] = my_u[d];
            }
        }
    }

    // Output total chunk decay (chunk_len >= 1 always: n_chunks = ceil division)
    if (tid == 0) {
        const uint decay_idx = (seq_id * n_chunks + chunk_id) * H + head_id;
        decay_out[decay_idx] = s_decay[chunk_len - 1];
    }
}
#version 450

#extension GL_EXT_control_flow_attributes : require

// Output kernel for chunked gated delta net
//
// For each chunk, combines inter-chunk and intra-chunk contributions:
//   o[t] = q[t]^T @ (exp(g_cumsum[t]) * S_chunk) + causal_attn(q, k, v_corrected)
//
// Grid: (n_chunks * H, n_seqs, 1)
// Workgroup: S_V threads (one per output column)

layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;

// local_size_x is specialized to S_V (constant_id 0): one invocation per
// output column, so the `col >= S_V` guard below never fires in practice.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint H;          // number of heads
    uint n_tokens;   // tokens per sequence
    uint n_seqs;     // number of sequences
    uint sq1, sq2, sq3;  // q/k element strides (head, token, seq)
    uint sv1, sv2, sv3;  // v element strides (unused in this kernel)
    uint sb1, sb2, sb3;  // beta/gate element strides (unused in this kernel)
    uint neq1, rq3;      // q head count and seq broadcast ratio
    uint n_chunks;       // ceil(n_tokens / CHUNK_SIZE)
    uint s_off;          // unused here; shared PC layout with the other kernels
};

layout(binding = 0) readonly buffer QBuf { float q_in[]; };
layout(binding = 1) readonly buffer KBuf { float k_in[]; };
layout(binding = 2) readonly buffer HBuf { float h_in[]; };       // per-chunk state snapshots (inter kernel)
layout(binding = 3) readonly buffer VNewBuf { float vnew_in[]; }; // corrected v (inter kernel)
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; }; // per-token log-decay cumsum (intra kernel)
layout(binding = 5) buffer DstBuf { float dst[]; };

// Broadcast staging for the current q[t] / k[j] rows, plus the chunk's gate cumsum.
shared float s_q[S_V];
shared float s_k[S_V];
shared float s_gcum[CHUNK_SIZE];

void main() {
    const uint chunk_head = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    const uint col = gl_LocalInvocationID.x;

    if (col >= S_V) return;

    const uint head_id = chunk_head % H;
    const uint chunk_id = chunk_head / H;

    // Map onto q/k tensor indices (q may have fewer heads/seqs than v).
    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint chunk_start = chunk_id * CHUNK_SIZE;
    const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);

    // Same 1/sqrt(S_v) scaling as the autoregressive path; applied once to the
    // combined (inter + intra) result below.
    const float scale = 1.0 / sqrt(float(S_V));

    const uint state_size = S_V * S_V;
    const uint h_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * state_size;

    // Each thread keeps its column of the chunk's entry state in registers.
    float state_col[S_V];
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        state_col[i] = h_in[h_base + i * S_V + col];
    }

    const uint wu_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;

    const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
    if (col < CHUNK_SIZE) {
        s_gcum[col] = (col < chunk_len) ? gcum_in[gcum_base + col] : 0.0;
    }

    // Preload vnew[j][col] into registers
    float my_vnew[CHUNK_SIZE];
    for (uint j = 0; j < chunk_len; j++) {
        my_vnew[j] = vnew_in[wu_base + j * S_V + col];
    }
    // Make s_gcum visible to all threads before it is read in the t-loop.
    barrier();

    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;

    for (uint t = 0; t < chunk_len; t++) {
        const uint t_global = chunk_start + t;

        const uint q_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
        s_q[col] = q_in[q_off + col];
        barrier();

        // Inter-chunk: o_inter = q^T @ (exp(g_cumsum[t]) * S)
        float decay_t = exp(s_gcum[t]);
        float o_inter = 0.0;
        [[unroll]] for (uint i = 0; i < S_V; i += 4) {
            o_inter += dot(
                vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]),
                decay_t * vec4(state_col[i], state_col[i+1], state_col[i+2], state_col[i+3])
            );
        }

        // Intra-chunk: o_intra = sum_{j<=t} dot(q[t], k[j]) * decay_mask * vnew[j][col]
        float o_intra = 0.0;
        for (uint j = 0; j <= t; j++) {
            const uint j_global = chunk_start + j;
            const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
            s_k[col] = k_in[kj_off + col];
            barrier();

            float qk = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
                qk += dot(
                    vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]),
                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
                );
            }

            // Causal decay mask between token t and earlier token j (log-space diff).
            float mask = exp(s_gcum[t] - s_gcum[j]);
            o_intra += qk * mask * my_vnew[j];

            // Protect s_k (and s_q on the last j) before the next overwrite.
            barrier();
        }

        // dst layout: (seq, token, head, col) — token stride is S_V * H.
        dst[attn_off + t_global * S_V * H + col] = (o_inter + o_intra) * scale;
    }
}
string_to_spv("gated_delta_net_chunk_intra_f32", "gated_delta_net_chunk_intra.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index abf914faa1..77f5af394b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8682,6 +8682,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 4, 1)); // chunked path: S_V=128, n_tokens=4 + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1)); // chunked path: full chunk + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 128, 1)); // chunked path: 2 chunks test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));