vulkan: add chunked parallel kernel infrastructure for GATED_DELTA_NET

Three-dispatch chunked pipeline for prompt processing acceleration:
intra-chunk WY decomposition, inter-chunk state propagation, output
combination. Currently disabled (threshold=UINT32_MAX).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Progeny Alpha 2026-03-10 22:51:11 -04:00
parent 463b6a963c
commit 949a7e86d3
6 changed files with 600 additions and 11 deletions

View File

@ -827,6 +827,9 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv7_f32;
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
vk_pipeline pipeline_gated_delta_net[3][2];
vk_pipeline pipeline_gated_delta_net_chunk_intra;
vk_pipeline pipeline_gated_delta_net_chunk_inter;
vk_pipeline pipeline_gated_delta_net_chunk_output;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@ -1468,6 +1471,18 @@ struct vk_op_gated_delta_net_push_constants {
float scale;
};
struct vk_op_gated_delta_net_chunk_push_constants {
uint32_t H;
uint32_t n_tokens;
uint32_t n_seqs;
uint32_t sq1, sq2, sq3;
uint32_t sv1, sv2, sv3;
uint32_t sb1, sb2, sb3;
uint32_t neq1, rq3;
uint32_t n_chunks;
uint32_t s_off;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
@ -4599,6 +4614,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_intra, "gated_delta_net_chunk_intra_f32_d128",
gated_delta_net_chunk_intra_f32_len, gated_delta_net_chunk_intra_f32_data, "main",
8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_inter, "gated_delta_net_chunk_inter_f32_d128",
gated_delta_net_chunk_inter_f32_len, gated_delta_net_chunk_inter_f32_data, "main",
9, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output, "gated_delta_net_chunk_output_f32_d128",
gated_delta_net_chunk_output_f32_len, gated_delta_net_chunk_output_f32_data, "main",
6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@ -10373,9 +10398,13 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static constexpr uint32_t GDN_CHUNK_SIZE = 64;
static constexpr uint32_t GDN_CHUNK_THRESHOLD = UINT32_MAX; // Disabled
static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_g = dst->src[3];
const ggml_tensor * src_beta = dst->src[4];
GGML_ASSERT(dst->buffer != nullptr);
@ -10386,11 +10415,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const bool kda = (src_g->ne[0] == (int64_t)S_v);
const bool use_chunked = !kda && S_v == 128 && n_tokens > GDN_CHUNK_THRESHOLD;
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
@ -10411,19 +10437,104 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t neq1 = (uint32_t)src_q->ne[1];
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
if (!use_chunked) {
// Autoregressive path (optimal for TG / small n_tokens)
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
return;
}
// Chunked parallel path (PP acceleration)
const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE;
vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra;
vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output;
ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);
// Scratch buffer layout within prealloc_split_k
const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);
const size_t w_off = 0;
const size_t u_off = wu_size;
const size_t vn_off = 2 * wu_size;
const size_t dec_off = 3 * wu_size;
const size_t gcum_off = dec_off + d_size;
const size_t h_off = gcum_off + g_size;
const size_t total_scratch = h_off + h_size;
if (ctx->prealloc_size_split_k < total_scratch) {
ctx->prealloc_size_split_k = total_scratch;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };
const vk_op_gated_delta_net_chunk_push_constants pc = {
H, n_tokens, n_seqs,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
n_chunks, s_off
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
// Dispatch 1: Intra-chunk (parallel across chunks)
// Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
scratch_w, scratch_u, scratch_dec, scratch_gcum},
pc, { n_chunks * H, n_seqs, 1u });
ggml_vk_sync_buffers(ctx, subctx);
// Dispatch 2: Inter-chunk state propagation (sequential across chunks)
// Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst)
ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
{src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
src_buf[5], scratch_h, scratch_vnew, dst_buf},
pc, { H, n_seqs, 1u });
ggml_vk_sync_buffers(ctx, subctx);
// Dispatch 3: Output (parallel across chunks)
// Bindings: Q, K, H, VNew, GCum, Dst
ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
{src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
pc, { n_chunks * H, n_seqs, 1u });
ctx->prealloc_split_k_need_sync = true;
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {

View File

@ -0,0 +1,126 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
// Inter-chunk state propagation for chunked gated delta net
//
// Sequential across chunks, parallel across state columns.
// For each chunk c:
// 1. Store state snapshot h[c] for output kernel
// 2. v_corrected = U - W @ S (C x d)
// 3. S_next = exp(g_total) * S + K_gated^T @ v_corrected (d x d)
//
// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
//
// Grid: (H, n_seqs, 1)
// Workgroup: S_V threads (one per state column)
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
uint n_tokens;
uint n_seqs;
uint sq1, sq2, sq3;
uint sv1, sv2, sv3;
uint sb1, sb2, sb3;
uint neq1, rq3;
uint n_chunks;
uint s_off;
};
layout(binding = 0) readonly buffer KBuf { float k_in[]; };
layout(binding = 1) readonly buffer WBuf { float w_in[]; };
layout(binding = 2) readonly buffer UBuf { float u_in[]; };
layout(binding = 3) readonly buffer DecayBuf { float decay_in[]; };
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; };
layout(binding = 5) readonly buffer StateBuf { float state_in[]; };
layout(binding = 6) writeonly buffer HBuf { float h_out[]; };
layout(binding = 7) writeonly buffer VNewBuf { float vnew_out[]; };
layout(binding = 8) buffer FinalBuf { float final_out[]; };
shared float s_w[S_V];
shared float s_kg[S_V];
void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
if (col >= S_V) return;
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;
float state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = state_in[state_base + i * S_V + col];
}
for (uint c = 0; c < n_chunks; c++) {
const uint chunk_start = c * CHUNK_SIZE;
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
const uint h_base = ((seq_id * n_chunks + c) * H + head_id) * state_size;
[[unroll]] for (uint i = 0; i < S_V; i++) {
h_out[h_base + i * S_V + col] = state[i];
}
const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V;
const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE;
const uint decay_idx = (seq_id * n_chunks + c) * H + head_id;
const float g_total = decay_in[decay_idx];
float delta[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
delta[i] = 0.0;
}
for (uint t = 0; t < chunk_len; t++) {
s_w[col] = w_in[wu_base + t * S_V + col];
barrier();
float ws = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
ws += dot(
vec4(s_w[i], s_w[i+1], s_w[i+2], s_w[i+3]),
vec4(state[i], state[i+1], state[i+2], state[i+3])
);
}
float vnew = u_in[wu_base + t * S_V + col] - ws;
vnew_out[wu_base + t * S_V + col] = vnew;
// K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
float g_cumsum_t = gcum_in[gcum_base + t];
float decay_factor = exp(g_total - g_cumsum_t);
const uint t_global = chunk_start + t;
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
s_kg[col] = k_in[k_off + col] * decay_factor;
barrier();
[[unroll]] for (uint i = 0; i < S_V; i++) {
delta[i] += s_kg[i] * vnew;
}
barrier();
}
float total_decay = exp(g_total);
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = total_decay * state[i] + delta[i];
}
}
// Write final state to dst at s_off
[[unroll]] for (uint i = 0; i < S_V; i++) {
final_out[s_off + state_base + i * S_V + col] = state[i];
}
}

View File

@ -0,0 +1,222 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
// Intra-chunk kernel for chunked gated delta net (non-KDA scalar gate)
//
// For each chunk of C=64 tokens, computes W and U using the WY representation.
// Uses a single A matrix with gates, compensates W by multiplying k by exp(g).
//
// Algorithm (FLA-equivalent):
// 1. g_cumsum[j] = cumsum(g[0..j]) within chunk (log-space)
// 2. A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
// 3. T = (I + A)^{-1} via forward substitution (row-by-row parallel)
// 4. U[i] = sum_j T[i][j] * beta[j] * v[j]
// 5. W[i] = sum_j T[i][j] * beta[j] * exp(g_cumsum[j]) * k[j]
// 6. Output g_cumsum[C-1] as total chunk decay
//
// Grid: (n_chunks * H, n_seqs, 1)
// Workgroup: CHUNK_SIZE threads (one per token in chunk)
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
layout(local_size_x_id = 1, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
uint n_tokens;
uint n_seqs;
uint sq1, sq2, sq3;
uint sv1, sv2, sv3;
uint sb1, sb2, sb3;
uint neq1, rq3;
uint n_chunks;
uint s_off;
};
layout(binding = 0) readonly buffer KBuf { float k_in[]; };
layout(binding = 1) readonly buffer VBuf { float v_in[]; };
layout(binding = 2) readonly buffer GBuf { float g_in[]; };
layout(binding = 3) readonly buffer BetaBuf { float beta_in[]; };
layout(binding = 4) writeonly buffer WBuf { float w_out[]; };
layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum
shared float s_A[CHUNK_SIZE * CHUNK_SIZE];
shared float s_decay[CHUNK_SIZE];
shared float s_beta[CHUNK_SIZE];
shared float s_k_broadcast[S_V];
shared float s_v_broadcast[S_V];
#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)]
void main() {
const uint chunk_head = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint tid = gl_LocalInvocationID.x;
const uint head_id = chunk_head % H;
const uint chunk_id = chunk_head / H;
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
const uint chunk_start = chunk_id * CHUNK_SIZE;
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
const uint global_t = chunk_start + tid;
const bool valid = tid < chunk_len;
// Load beta and gate
if (valid) {
const uint gb_off = seq_id * sb3 + global_t * sb2 + head_id * sb1;
s_beta[tid] = beta_in[gb_off];
s_decay[tid] = g_in[gb_off];
} else {
s_beta[tid] = 0.0;
s_decay[tid] = 0.0;
}
barrier();
// Step 1: Prefix sum of log-gates
if (tid == 0) {
for (uint i = 1; i < chunk_len; i++) {
s_decay[i] += s_decay[i - 1];
}
}
barrier();
// Output per-token g_cumsum for inter-chunk and output kernels
if (valid) {
const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
gcum_out[gcum_base + tid] = s_decay[tid];
}
// Load my k vector into registers
float my_k[S_V];
if (valid) {
const uint k_off = iq3 * sq3 + global_t * sq2 + iq1 * sq1;
[[unroll]] for (uint d = 0; d < S_V; d++) {
my_k[d] = k_in[k_off + d];
}
} else {
[[unroll]] for (uint d = 0; d < S_V; d++) {
my_k[d] = 0.0;
}
}
// Step 2: Build A matrix using shared memory broadcast for k[j]
// A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
// Initialize to zero
for (uint j = 0; j < CHUNK_SIZE; j++) {
A(tid, j) = 0.0;
}
barrier();
// For each column j, broadcast k[j] via shared memory, all threads compute their row
for (uint j = 0; j < chunk_len; j++) {
// Broadcast k[j] — need multiple passes when S_V > CHUNK_SIZE
{
const uint j_global = chunk_start + j;
const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
for (uint d = tid; d < S_V; d += CHUNK_SIZE) {
s_k_broadcast[d] = k_in[kj_off + d];
}
}
barrier();
if (valid && tid > j) {
float dot_kk = 0.0;
[[unroll]] for (uint d = 0; d < S_V; d += 4) {
dot_kk += dot(
vec4(my_k[d], my_k[d+1], my_k[d+2], my_k[d+3]),
vec4(s_k_broadcast[d], s_k_broadcast[d+1],
s_k_broadcast[d+2], s_k_broadcast[d+3])
);
}
float decay_factor = exp(s_decay[tid] - s_decay[j]);
A(tid, j) = -s_beta[tid] * dot_kk * decay_factor;
}
barrier();
}
// Step 3: Forward substitution T = (I + A)^{-1}
// Process row by row. For row i, all threads j < i compute in parallel:
// T[i][j] += sum_m T[i][m] * T[m][j] for m in [j..i-1]
// The A matrix is modified in-place to become T.
for (uint i = 1; i < chunk_len; i++) {
// Each thread with tid < i computes T[i][tid]
if (tid < i) {
float sum = 0.0;
for (uint m = tid; m < i; m++) {
sum += A(i, m) * A(m, tid);
}
A(i, tid) += sum;
}
barrier();
}
// Add identity
if (valid) {
A(tid, tid) = 1.0;
}
barrier();
// Step 4+5: Compute W and U via shared memory broadcast + register accumulation
// U[tid][d] = sum_j T[tid][j] * beta[j] * v[j][d]
// W[tid][d] = sum_j T[tid][j] * beta[j] * exp(g_cumsum[j]) * k[j][d]
//
// For each j, broadcast k[j]*exp(g[j]) and v[j] via shared memory.
// Accumulate in d-tiles of 32 to limit register pressure.
const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
const uint TILE_D = 32;
for (uint d_start = 0; d_start < S_V; d_start += TILE_D) {
float my_w[TILE_D];
float my_u[TILE_D];
for (uint d = 0; d < TILE_D; d++) {
my_w[d] = 0.0;
my_u[d] = 0.0;
}
for (uint j = 0; j < chunk_len; j++) {
const uint j_global = chunk_start + j;
const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
const uint vj_off = seq_id * sv3 + j_global * sv2 + head_id * sv1;
float eg = exp(s_decay[j]);
// Broadcast tile of k[j] and v[j]
for (uint d = tid; d < S_V; d += CHUNK_SIZE) {
if (d >= d_start && d < d_start + TILE_D) {
s_k_broadcast[d] = k_in[kj_off + d] * eg;
s_v_broadcast[d] = v_in[vj_off + d];
}
}
barrier();
if (valid && j <= tid) {
float t_beta = A(tid, j) * s_beta[j];
[[unroll]] for (uint d = 0; d < TILE_D; d++) {
my_w[d] += t_beta * s_k_broadcast[d_start + d];
my_u[d] += t_beta * s_v_broadcast[d_start + d];
}
}
barrier();
}
// Write tile to global memory
if (valid) {
for (uint d = 0; d < TILE_D; d++) {
w_out[out_base + tid * S_V + d_start + d] = my_w[d];
u_out[out_base + tid * S_V + d_start + d] = my_u[d];
}
}
}
// Output total chunk decay
if (tid == 0) {
const uint decay_idx = (seq_id * n_chunks + chunk_id) * H + head_id;
decay_out[decay_idx] = s_decay[chunk_len - 1];
}
}

View File

@ -0,0 +1,124 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
// Output kernel for chunked gated delta net
//
// For each chunk, combines inter-chunk and intra-chunk contributions:
// o[t] = q[t]^T @ (exp(g_cumsum[t]) * S_chunk) + causal_attn(q, k, v_corrected)
//
// Grid: (n_chunks * H, n_seqs, 1)
// Workgroup: S_V threads (one per output column)
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint H;
uint n_tokens;
uint n_seqs;
uint sq1, sq2, sq3;
uint sv1, sv2, sv3;
uint sb1, sb2, sb3;
uint neq1, rq3;
uint n_chunks;
uint s_off;
};
layout(binding = 0) readonly buffer QBuf { float q_in[]; };
layout(binding = 1) readonly buffer KBuf { float k_in[]; };
layout(binding = 2) readonly buffer HBuf { float h_in[]; };
layout(binding = 3) readonly buffer VNewBuf { float vnew_in[]; };
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; };
layout(binding = 5) buffer DstBuf { float dst[]; };
shared float s_q[S_V];
shared float s_k[S_V];
shared float s_gcum[CHUNK_SIZE];
void main() {
const uint chunk_head = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;
if (col >= S_V) return;
const uint head_id = chunk_head % H;
const uint chunk_id = chunk_head / H;
const uint iq1 = head_id % neq1;
const uint iq3 = seq_id / rq3;
const uint chunk_start = chunk_id * CHUNK_SIZE;
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
const float scale = 1.0 / sqrt(float(S_V));
const uint state_size = S_V * S_V;
const uint h_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * state_size;
float state_col[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state_col[i] = h_in[h_base + i * S_V + col];
}
const uint wu_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
if (col < CHUNK_SIZE) {
s_gcum[col] = (col < chunk_len) ? gcum_in[gcum_base + col] : 0.0;
}
// Preload vnew[j][col] into registers
float my_vnew[CHUNK_SIZE];
for (uint j = 0; j < chunk_len; j++) {
my_vnew[j] = vnew_in[wu_base + j * S_V + col];
}
barrier();
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
for (uint t = 0; t < chunk_len; t++) {
const uint t_global = chunk_start + t;
const uint q_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
s_q[col] = q_in[q_off + col];
barrier();
// Inter-chunk: o_inter = q^T @ (exp(g_cumsum[t]) * S)
float decay_t = exp(s_gcum[t]);
float o_inter = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
o_inter += dot(
vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]),
decay_t * vec4(state_col[i], state_col[i+1], state_col[i+2], state_col[i+3])
);
}
// Intra-chunk: o_intra = sum_{j<=t} dot(q[t], k[j]) * decay_mask * vnew[j][col]
float o_intra = 0.0;
for (uint j = 0; j <= t; j++) {
const uint j_global = chunk_start + j;
const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
s_k[col] = k_in[kj_off + col];
barrier();
float qk = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
qk += dot(
vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]),
vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
);
}
float mask = exp(s_gcum[t] - s_gcum[j]);
o_intra += qk * mask * my_vnew[j];
barrier();
}
dst[attn_off + t_global * S_V * H + col] = (o_inter + o_intra) * scale;
}
}

View File

@ -988,6 +988,9 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
string_to_spv("gated_delta_net_chunk_intra_f32", "gated_delta_net_chunk_intra.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

View File

@ -8682,6 +8682,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 4, 1)); // chunked path: S_V=128, n_tokens=4
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1)); // chunked path: full chunk
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 128, 1)); // chunked path: 2 chunks
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));