vulkan: final cleanup of chunked GDN intra/inter shaders
Intra:
- Strip all section/inline comments to match codebase style
- Add [[unroll]] to fixed-bound loops (A-matrix zero, W/U tile init/write)
- Guard against chunk_len == 0 index underflow in s_decay[chunk_len - 1]

Inter:
- Strip final comment

No functional changes for non-empty chunks. 16/16 tests pass.
This commit is contained in:
parent
88396c3923
commit
ab79f14b42
|
|
@ -105,7 +105,6 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
// Write final state to dst at s_off
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
final_out[s_off + state_base + col * S_V + i] = state[i];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
// Intra-chunk WY decomposition for chunked gated delta net
|
||||
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
|
||||
|
||||
|
|
@ -28,7 +26,7 @@ layout(binding = 3) readonly buffer BetaBuf { float beta_in[]; };
|
|||
layout(binding = 4) writeonly buffer WBuf { float w_out[]; };
|
||||
layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
|
||||
layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
|
||||
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum
|
||||
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; };
|
||||
|
||||
const uint A_STRIDE = CHUNK_SIZE + 1;
|
||||
shared float s_A[CHUNK_SIZE * A_STRIDE];
|
||||
|
|
@ -54,7 +52,6 @@ void main() {
|
|||
const uint global_t = chunk_start + tid;
|
||||
const bool valid = tid < chunk_len;
|
||||
|
||||
// Load beta and gate
|
||||
if (valid) {
|
||||
const uint gb_off = seq_id * sb3 + global_t * sb2 + head_id * sb1;
|
||||
s_beta[tid] = beta_in[gb_off];
|
||||
|
|
@ -89,7 +86,7 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
for (uint j = 0; j < CHUNK_SIZE; j++) {
|
||||
[[unroll]] for (uint j = 0; j < CHUNK_SIZE; j++) {
|
||||
A(tid, j) = 0.0;
|
||||
}
|
||||
barrier();
|
||||
|
|
@ -119,7 +116,6 @@ void main() {
|
|||
}
|
||||
barrier();
|
||||
|
||||
// Forward substitution: T = (I + A)^{-1}, in-place
|
||||
for (uint i = 1; i < chunk_len; i++) {
|
||||
if (tid < i) {
|
||||
float sum = 0.0;
|
||||
|
|
@ -136,14 +132,13 @@ void main() {
|
|||
}
|
||||
barrier();
|
||||
|
||||
// W and U via tiled broadcast accumulation
|
||||
const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
|
||||
const uint TILE_D = 32;
|
||||
|
||||
for (uint d_start = 0; d_start < S_V; d_start += TILE_D) {
|
||||
float my_w[TILE_D];
|
||||
float my_u[TILE_D];
|
||||
for (uint d = 0; d < TILE_D; d++) {
|
||||
[[unroll]] for (uint d = 0; d < TILE_D; d++) {
|
||||
my_w[d] = 0.0;
|
||||
my_u[d] = 0.0;
|
||||
}
|
||||
|
|
@ -154,7 +149,6 @@ void main() {
|
|||
const uint vj_off = seq_id * sv3 + j_global * sv2 + head_id * sv1;
|
||||
float eg = exp(s_decay[j]);
|
||||
|
||||
// Broadcast tile of k[j] and v[j]
|
||||
for (uint d = tid; d < S_V; d += CHUNK_SIZE) {
|
||||
if (d >= d_start && d < d_start + TILE_D) {
|
||||
s_k_broadcast[d] = k_in[kj_off + d] * eg;
|
||||
|
|
@ -173,17 +167,15 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
// Write tile to global memory
|
||||
if (valid) {
|
||||
for (uint d = 0; d < TILE_D; d++) {
|
||||
[[unroll]] for (uint d = 0; d < TILE_D; d++) {
|
||||
w_out[out_base + tid * S_V + d_start + d] = my_w[d];
|
||||
u_out[out_base + tid * S_V + d_start + d] = my_u[d];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Output total chunk decay
|
||||
if (tid == 0) {
|
||||
if (tid == 0 && chunk_len > 0) {
|
||||
const uint decay_idx = (seq_id * n_chunks + chunk_id) * H + head_id;
|
||||
decay_out[decay_idx] = s_decay[chunk_len - 1];
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue