vulkan: reduce barriers and bank conflicts in chunked intra kernel

Remove the unnecessary barrier after the A-matrix dot-product writes.
Each thread writes only to its own row, and s_A isn't read across
threads until forward substitution. This cuts A-matrix barriers from
128 to 65 (one per broadcast plus one before forward substitution).

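Pad the s_A row stride from 64 to 65 floats (A_STRIDE = CHUNK_SIZE + 1)
to eliminate shared-memory bank conflicts in the W/U accumulation
phase, where all active threads read A(tid, j) with the same j value.
With a stride of 64, consecutive rows are a multiple of the usual bank
count apart, so those reads hit the same bank; the odd stride spreads
them across banks.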

GDN per-op: 5205 → 5136 µs. Combined with inter fusion: 6818 → 5136 µs
(-24.7%). 16/16 tests pass.
This commit is contained in:
Progeny Alpha 2026-03-14 22:48:11 -04:00
parent 530e5bb117
commit 88396c3923
1 changed file with 4 additions and 3 deletions

View File

@@ -30,13 +30,14 @@ layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum
shared float s_A[CHUNK_SIZE * CHUNK_SIZE];
const uint A_STRIDE = CHUNK_SIZE + 1;
shared float s_A[CHUNK_SIZE * A_STRIDE];
shared float s_decay[CHUNK_SIZE];
shared float s_beta[CHUNK_SIZE];
shared float s_k_broadcast[S_V];
shared float s_v_broadcast[S_V];
#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)]
#define A(i,j) s_A[(i) * A_STRIDE + (j)]
void main() {
const uint chunk_head = gl_WorkGroupID.x;
@@ -115,8 +116,8 @@ void main() {
float decay_factor = exp(s_decay[tid] - s_decay[j]);
A(tid, j) = -s_beta[tid] * dot_kk * decay_factor;
}
barrier();
}
barrier();
// Forward substitution: T = (I + A)^{-1}, in-place
for (uint i = 1; i < chunk_len; i++) {