Revert "vulkan: fused inter+output kernel for chunked GDN"

This reverts commit 08c355c01f3a298ef943216d4c55367a1c967286.
This commit is contained in:
Progeny Alpha 2026-03-13 19:10:20 -04:00
parent b0323615c9
commit efbde13283
3 changed files with 57 additions and 435 deletions

View File

@@ -831,7 +831,6 @@ struct vk_device_struct {
     vk_pipeline pipeline_gated_delta_net_chunk_inter;
     vk_pipeline pipeline_gated_delta_net_chunk_output;
     vk_pipeline pipeline_gated_delta_net_chunk_output_cm;
-    vk_pipeline pipeline_gated_delta_net_chunk_fused_cm;
    vk_pipeline pipeline_ssm_scan_f32_d128;
    vk_pipeline pipeline_ssm_scan_f32_d256;
    vk_pipeline pipeline_ssm_conv_f32;
@@ -4630,9 +4629,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output_cm, "gated_delta_net_chunk_output_cm1_f32_d128",
             gated_delta_net_chunk_output_cm1_f32_len, gated_delta_net_chunk_output_cm1_f32_data, "main",
             6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_fused_cm, "gated_delta_net_chunk_fused_cm_f32_d128",
-            gated_delta_net_chunk_fused_cm_f32_len, gated_delta_net_chunk_fused_cm_f32_data, "main",
-            8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true);
     }
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
@@ -10474,12 +10470,45 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
     // Chunked parallel path (PP acceleration)
     const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE;
     vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra;
-    vk_pipeline pl_fused = ctx->device->pipeline_gated_delta_net_chunk_fused_cm;
-
-    const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
-    const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
-    const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
+    vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
+    vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output_cm
+        ? ctx->device->pipeline_gated_delta_net_chunk_output_cm
+        : ctx->device->pipeline_gated_delta_net_chunk_output;
+
+    ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);
+
+    // Scratch buffer layout within prealloc_split_k
+    const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
+    const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
+    const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
+    const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);
+    const size_t w_off = 0;
+    const size_t u_off = wu_size;
+    const size_t vn_off = 2 * wu_size;
+    const size_t dec_off = 3 * wu_size;
+    const size_t gcum_off = dec_off + d_size;
+    const size_t h_off = gcum_off + g_size;
+    const size_t total_scratch = h_off + h_size;
+
+    if (ctx->prealloc_size_split_k < total_scratch) {
+        ctx->prealloc_size_split_k = total_scratch;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
+    vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
+    vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
+    vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
+    vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
+    vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };

     const vk_op_gated_delta_net_chunk_push_constants pc = {
         H, n_tokens, n_seqs,
@@ -10490,95 +10519,29 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
         n_chunks, s_off
     };

-    if (pl_fused) {
-        // Fused inter+output path: 2 dispatches, no vnew/h scratch
-        ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
-        ggml_pipeline_request_descriptor_sets(ctx, pl_fused, 1);
-
-        const size_t w_off = 0;
-        const size_t u_off = wu_size;
-        const size_t dec_off = 2 * wu_size;
-        const size_t gcum_off = dec_off + d_size;
-        const size_t total_scratch = gcum_off + g_size;
-
-        if (ctx->prealloc_size_split_k < total_scratch) {
-            ctx->prealloc_size_split_k = total_scratch;
-            ggml_vk_preallocate_buffers(ctx, subctx);
-        }
-        if (ctx->prealloc_split_k_need_sync) {
-            ggml_vk_sync_buffers(ctx, subctx);
-        }
-
-        vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
-        vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
-        vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
-        vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
-            {src_buf[1], src_buf[2], src_buf[3], src_buf[4],
-             scratch_w, scratch_u, scratch_dec, scratch_gcum},
-            pc, { n_chunks * H, n_seqs, 1u });
-        ggml_vk_sync_buffers(ctx, subctx);
-
-        // Bindings: Q, K, W, U, Decay, GCum, State, Dst
-        ggml_vk_dispatch_pipeline(ctx, subctx, pl_fused,
-            {src_buf[0], src_buf[1], scratch_w, scratch_u,
-             scratch_dec, scratch_gcum, src_buf[5], dst_buf},
-            pc, { H, n_seqs, 1u });
-    } else {
-        // Fallback: 3-dispatch path (no coopmat)
-        vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
-        vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output;
-
-        ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
-        ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
-        ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);
-
-        const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);
-        const size_t w_off = 0;
-        const size_t u_off = wu_size;
-        const size_t vn_off = 2 * wu_size;
-        const size_t dec_off = 3 * wu_size;
-        const size_t gcum_off = dec_off + d_size;
-        const size_t h_off = gcum_off + g_size;
-        const size_t total_scratch = h_off + h_size;
-
-        if (ctx->prealloc_size_split_k < total_scratch) {
-            ctx->prealloc_size_split_k = total_scratch;
-            ggml_vk_preallocate_buffers(ctx, subctx);
-        }
-        if (ctx->prealloc_split_k_need_sync) {
-            ggml_vk_sync_buffers(ctx, subctx);
-        }
-
-        vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
-        vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
-        vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
-        vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
-        vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
-        vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
-            {src_buf[1], src_buf[2], src_buf[3], src_buf[4],
-             scratch_w, scratch_u, scratch_dec, scratch_gcum},
-            pc, { n_chunks * H, n_seqs, 1u });
-        ggml_vk_sync_buffers(ctx, subctx);
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
-            {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
-             src_buf[5], scratch_h, scratch_vnew, dst_buf},
-            pc, { H, n_seqs, 1u });
-        ggml_vk_sync_buffers(ctx, subctx);
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
-            {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
-            pc, { n_chunks * H, n_seqs, 1u });
-    }
+    // Dispatch 1: Intra-chunk (parallel across chunks)
+    // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out
+    ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
+        {src_buf[1], src_buf[2], src_buf[3], src_buf[4],
+         scratch_w, scratch_u, scratch_dec, scratch_gcum},
+        pc, { n_chunks * H, n_seqs, 1u });
+    ggml_vk_sync_buffers(ctx, subctx);
+
+    // Dispatch 2: Inter-chunk state propagation (sequential across chunks)
+    // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst)
+    ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
+        {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
+         src_buf[5], scratch_h, scratch_vnew, dst_buf},
+        pc, { H, n_seqs, 1u });
+    ggml_vk_sync_buffers(ctx, subctx);
+
+    // Dispatch 3: Output (parallel across chunks)
+    // Bindings: Q, K, H, VNew, GCum, Dst
+    ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
+        {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
+        pc, { n_chunks * H, n_seqs, 1u });

     ctx->prealloc_split_k_need_sync = true;
 }

View File

@ -1,340 +0,0 @@
#version 450

// Control-flow attributes ([[unroll]]), 16-bit storage, and explicit f16/i32
// arithmetic types are required; subgroup basic/arithmetic/shuffle and
// cooperative-matrix support back the coopmat GEMMs and the f16 shuffle pack.
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable

#include "types.glsl"

// Fused inter+output kernel for chunked gated delta net
//
// Merges the inter-chunk state propagation and output computation into
// one dispatch, eliminating VNew and H_snapshot scratch buffers plus
// one dispatch barrier.
//
// Step 1: Load K, Q, gcum → shared
// Step 2: A = Q @ K^T (coopmat)
// Step 3: Decay mask → sh_adecay + O_inter = Q @ state → dst (parallel)
// Step 4: vnew = U - W@state → sh_kv (f16), accumulate delta
// Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM)
// Step 6: state = exp(decay) * state + delta
//
// Grid: (H, n_seqs, 1) — sequential chunk loop
// Workgroup: 256 threads = 4 subgroups

// Specialization constants; defaults match the d128 pipeline variant
// created on the host ({256, 64, 128}).
layout(constant_id = 0) const uint WG_SIZE = 256;     // workgroup size (threads)
layout(constant_id = 1) const uint CHUNK_SIZE = 64;   // tokens per chunk
layout(constant_id = 2) const uint S_V = 128;         // head state dimension

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint H;            // number of heads
    uint n_tokens;     // tokens per sequence
    uint n_seqs;       // number of sequences
    uint sq1, sq2, sq3; // Q/K source strides — presumably element strides for dims 1..3; confirm against host push-constant setup
    uint sv1, sv2, sv3; // V source strides (unused in this kernel's visible code)
    uint sb1, sb2, sb3; // Beta source strides (unused in this kernel's visible code)
    uint neq1, rq3;    // head/sequence broadcast factors for Q/K indexing
    uint n_chunks;     // number of CHUNK_SIZE chunks covering n_tokens
    uint s_off;        // element offset of the final-state region inside dst
};

layout(binding = 0) readonly buffer QBuf { float q_in[]; };
layout(binding = 1) readonly buffer KBuf { float k_in[]; };
layout(binding = 2) readonly buffer WBuf { float w_in[]; };     // per-token W from intra kernel
layout(binding = 3) readonly buffer UBuf { float u_in[]; };     // per-token U from intra kernel
layout(binding = 4) readonly buffer DecBuf { float decay_in[]; };   // per-chunk total decay
layout(binding = 5) readonly buffer GCumBuf { float gcum_in[]; };   // per-token cumulative gate
layout(binding = 6) readonly buffer StBuf { float state_in[]; };    // incoming recurrent state
layout(binding = 7) buffer DstBuf { float dst[]; };                 // attention output + final state

// Cooperative-matrix tile shape (16x16x16) and tile counts per chunk/state dim.
const uint TM = 16;
const uint TN = 16;
const uint TK = 16;
const uint C_TILES = CHUNK_SIZE / TM;
const uint D_TILES = S_V / TN;

// Row strides (in f16vec4/vec4 units) for shared tiles; the +2 padding is
// presumably to avoid shared-memory bank conflicts — TODO confirm.
const uint QK_STRIDE = S_V / 4 + 2;
const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2;

shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE];          // Q rows for the current chunk (f16)
shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE];         // K rows, later reused for vnew (f16)
shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE];     // raw Q@K^T scores (f32)
shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE]; // decay-masked attention (f16)
shared float sh_gcum[CHUNK_SIZE];                     // cumulative gate per token in chunk
shared float sh_w[S_V];                               // current token's W row
shared float sh_kg[S_V];                              // current token's decayed K row
// One workgroup processes one (head, sequence) pair, walking chunks
// sequentially and carrying the S_V x S_V recurrent state in registers
// (one column of the state per thread in threads 0..S_V-1).
void main() {
    const uint tid = gl_LocalInvocationIndex;
    const uint sg_id = gl_SubgroupID;
    const uint si = gl_SubgroupInvocationID;
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    const uint col = tid;                      // state column owned by this thread (valid when < S_V)
    const uint iq1 = head_id % neq1;           // source head index after broadcast
    const uint iq3 = seq_id / rq3;             // source sequence index after broadcast
    const float scale = 1.0 / sqrt(float(S_V));
    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;
    const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;

    // ================================================================
    // Load state into registers (threads 0-127)
    // Each thread holds one column of the S_V x S_V state matrix.
    // ================================================================
    float state[S_V];
    if (col < S_V) {
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            state[i] = state_in[state_base + i * S_V + col];
        }
    }

    // ================================================================
    // Chunk loop
    // ================================================================
    for (uint c = 0; c < n_chunks; c++) {
        const uint chunk_start = c * CHUNK_SIZE;
        const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);   // last chunk may be partial
        const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V;
        const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE;
        const uint decay_idx = (seq_id * n_chunks + c) * H + head_id;
        const float g_total = decay_in[decay_idx];   // total log-decay over this chunk

        // ============================================================
        // Step 1: Load K → sh_kv, Q → sh_q, gcum → sh_gcum
        // Out-of-range rows are zero-filled so coopmat tiles stay valid.
        // ============================================================
        if (tid < CHUNK_SIZE) {
            sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0;
        }
        for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) {
            const uint row = idx / (S_V / 4);
            const uint col4 = idx % (S_V / 4);
            f16vec4 q_val = f16vec4(0.0);
            f16vec4 k_val = f16vec4(0.0);
            if (row < chunk_len) {
                const uint off = iq3 * sq3 + (chunk_start + row) * sq2 + iq1 * sq1 + col4 * 4;
                q_val = f16vec4(q_in[off], q_in[off + 1], q_in[off + 2], q_in[off + 3]);
                k_val = f16vec4(k_in[off], k_in[off + 1], k_in[off + 2], k_in[off + 3]);
            }
            sh_q[row * QK_STRIDE + col4] = q_val;
            sh_kv[row * QK_STRIDE + col4] = k_val;
        }
        barrier();

        // ============================================================
        // Step 2: Q @ K^T (coopmat, f16→f32)
        // Each subgroup computes one TM-row band of the CHUNK x CHUNK
        // score matrix; results land in sh_attn.
        // ============================================================
        coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> A_acc[C_TILES];
        [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
            A_acc[tj] = coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
        }
        {
            coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> Q_mat;
            coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> KT_mat;
            [[unroll]] for (uint dk = 0; dk < D_TILES; dk++) {
                coopMatLoad(Q_mat, sh_q,
                    sg_id * TM * QK_STRIDE + dk * (TK / 4),
                    QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
                [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
                    // Column-major load of K gives the K^T operand.
                    coopMatLoad(KT_mat, sh_kv,
                        tj * TN * QK_STRIDE + dk * (TK / 4),
                        QK_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
                    A_acc[tj] = coopMatMulAdd(Q_mat, KT_mat, A_acc[tj]);
                }
            }
        }
        [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
            coopMatStore(A_acc[tj], sh_attn,
                sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4),
                ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
        }
        barrier();

        // ============================================================
        // Step 3: Decay mask + inter-chunk output (parallel, no conflict)
        // Causal mask (j <= t) with pairwise decay exp(g_t - g_j);
        // clamped to the f16 finite range before narrowing.
        // ============================================================
        for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) {
            const uint t = idx / (CHUNK_SIZE / 4);
            const uint j4 = idx % (CHUNK_SIZE / 4);
            f16vec4 val = f16vec4(0.0);
            if (t < chunk_len) {
                const float g_t = sh_gcum[t];
                [[unroll]] for (uint k = 0; k < 4; k++) {
                    const uint j = j4 * 4 + k;
                    if (j <= t && j < chunk_len) {
                        val[k] = float16_t(clamp(
                            exp(g_t - sh_gcum[j]) * sh_attn[t * ATTN_V4_STRIDE + j4][k],
                            -65504.0, 65504.0));
                    }
                }
            }
            sh_adecay[t * ATTN_V4_STRIDE + j4] = val;
        }
        // O_inter = Q @ state: written to dst now, O_intra is added in Step 5.
        if (col < S_V) {
            for (uint t = 0; t < chunk_len; t++) {
                float o_inter = 0.0;
                [[unroll]] for (uint i = 0; i < S_V / 4; i++) {
                    vec4 q_f32 = vec4(sh_q[t * QK_STRIDE + i]);
                    o_inter += dot(q_f32,
                                   vec4(state[i*4], state[i*4+1],
                                        state[i*4+2], state[i*4+3]));
                }
                dst[attn_off + (chunk_start + t) * S_V * H + col] = exp(sh_gcum[t]) * o_inter * scale;
            }
        }
        barrier();

        // ============================================================
        // Step 4: vnew = U - W@state → sh_kv, accumulate delta
        // Threads 0-127: load w, compute vnew, write sh_kv via shuffle
        // Threads 128-255: load k_gated into sh_kg
        // ============================================================
        float delta[S_V];
        if (col < S_V) {
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                delta[i] = 0.0;
            }
        }
        for (uint t = 0; t < chunk_len; t++) {
            if (col < S_V) {
                sh_w[col] = w_in[wu_base + t * S_V + col];
            }
            if (tid >= 128 && tid < 256) {
                const uint lc = tid - 128;
                const float gcum_t = gcum_in[gcum_base + t];
                const float decay_factor = exp(g_total - gcum_t);   // decay from token t to chunk end
                const uint k_off = iq3 * sq3 + (chunk_start + t) * sq2 + iq1 * sq1;
                sh_kg[lc] = k_in[k_off + lc] * decay_factor;
            }
            barrier();
            if (col < S_V) {
                float ws = 0.0;
                [[unroll]] for (uint i = 0; i < S_V; i += 4) {
                    ws += dot(vec4(sh_w[i], sh_w[i+1], sh_w[i+2], sh_w[i+3]),
                              vec4(state[i], state[i+1], state[i+2], state[i+3]));
                }
                float vnew = u_in[wu_base + t * S_V + col] - ws;
                float vnew_scaled = clamp(vnew * scale, -65504.0, 65504.0);
                // Gather 4 lane-adjacent values via shuffle so one lane per
                // quad writes a packed f16vec4 into sh_kv.
                float16_t v0 = float16_t(subgroupShuffle(vnew_scaled, si & ~3u));
                float16_t v1 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 1u));
                float16_t v2 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 2u));
                float16_t v3 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 3u));
                if ((si & 3u) == 0u) {
                    sh_kv[t * QK_STRIDE + (col >> 2)] = f16vec4(v0, v1, v2, v3);
                }
                // Rank-1 update contribution: delta += k_gated ⊗ vnew (this column).
                [[unroll]] for (uint i = 0; i < S_V; i++) {
                    delta[i] += sh_kg[i] * vnew;
                }
            }
            barrier();
        }

        // ============================================================
        // Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM)
        // Full chunks take the coopmat path; partial (last) chunks fall
        // back to a scalar accumulate to respect chunk_len.
        // ============================================================
        if (chunk_len == CHUNK_SIZE) {
            coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> A_mat;
            coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> V_mat;
            [[unroll]] for (uint dn = 0; dn < D_TILES; dn++) {
                coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> O_acc;
                // Accumulate on top of the O_inter already written to dst.
                coopMatLoad(O_acc, dst,
                    attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN,
                    S_V * H, gl_CooperativeMatrixLayoutRowMajor);
                [[unroll]] for (uint dk = 0; dk < C_TILES; dk++) {
                    coopMatLoad(A_mat, sh_adecay,
                        sg_id * TM * ATTN_V4_STRIDE + dk * (TK / 4),
                        ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
                    coopMatLoad(V_mat, sh_kv,
                        dk * TK * QK_STRIDE + dn * (TN / 4),
                        QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
                    O_acc = coopMatMulAdd(A_mat, V_mat, O_acc);
                }
                coopMatStore(O_acc, dst,
                    attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN,
                    S_V * H, gl_CooperativeMatrixLayoutRowMajor);
            }
        } else {
            if (col < S_V) {
                const uint col4 = col / 4;
                const uint comp = col % 4;
                // Pre-gather this thread's vnew column from sh_kv into registers.
                vec4 my_vnew_v4[CHUNK_SIZE / 4];
                [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) {
                    my_vnew_v4[j4] = vec4(
                        float(sh_kv[(j4*4  ) * QK_STRIDE + col4][comp]),
                        float(sh_kv[(j4*4+1) * QK_STRIDE + col4][comp]),
                        float(sh_kv[(j4*4+2) * QK_STRIDE + col4][comp]),
                        float(sh_kv[(j4*4+3) * QK_STRIDE + col4][comp]));
                }
                for (uint t = 0; t < chunk_len; t++) {
                    float o_intra = 0.0;
                    [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) {
                        o_intra += dot(vec4(sh_adecay[t * ATTN_V4_STRIDE + j4]), my_vnew_v4[j4]);
                    }
                    dst[attn_off + (chunk_start + t) * S_V * H + col] += o_intra;
                }
            }
        }

        // ============================================================
        // Step 6: State update
        // state = exp(g_total) * state + delta (per owned column).
        // ============================================================
        if (col < S_V) {
            const float total_decay = exp(g_total);
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                state[i] = total_decay * state[i] + delta[i];
            }
        }
    }

    // ================================================================
    // Write final state to dst
    // ================================================================
    if (col < S_V) {
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            dst[s_off + state_base + i * S_V + col] = state[i];
        }
    }
}

View File

@@ -992,7 +992,6 @@ void process_shaders() {
     string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("gated_delta_net_chunk_output_cm1_f32", "gated_delta_net_chunk_output_cm1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
-    string_to_spv("gated_delta_net_chunk_fused_cm_f32", "gated_delta_net_chunk_fused_cm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));