vulkan: optimize chunked intra kernel barrier and bank conflicts
Remove unnecessary barrier after A-matrix dot product writes. Each thread writes only to its own row; s_A isn't read cross-thread until forward substitution. Cuts A-matrix barriers from 128 to 65 (one per broadcast + one before forward sub). Pad s_A stride from 64 to 65 to eliminate bank conflicts in the W/U accumulation phase where all active threads read A(tid, j) with the same j value. GDN per-op: 5205 → 5136 µs. Combined with inter fusion: 6818 → 5136 µs (-24.7%). 16/16 tests pass.
This commit is contained in:
parent
530e5bb117
commit
88396c3923
|
|
@ -30,13 +30,14 @@ layout(binding = 5) writeonly buffer UBuf { float u_out[]; };
|
|||
layout(binding = 6) writeonly buffer DecayBuf { float decay_out[]; };
|
||||
layout(binding = 7) writeonly buffer GCumBuf { float gcum_out[]; }; // per-token g_cumsum
|
||||
|
||||
shared float s_A[CHUNK_SIZE * CHUNK_SIZE];
|
||||
const uint A_STRIDE = CHUNK_SIZE + 1;
|
||||
shared float s_A[CHUNK_SIZE * A_STRIDE];
|
||||
shared float s_decay[CHUNK_SIZE];
|
||||
shared float s_beta[CHUNK_SIZE];
|
||||
shared float s_k_broadcast[S_V];
|
||||
shared float s_v_broadcast[S_V];
|
||||
|
||||
#define A(i,j) s_A[(i) * CHUNK_SIZE + (j)]
|
||||
#define A(i,j) s_A[(i) * A_STRIDE + (j)]
|
||||
|
||||
void main() {
|
||||
const uint chunk_head = gl_WorkGroupID.x;
|
||||
|
|
@ -115,8 +116,8 @@ void main() {
|
|||
float decay_factor = exp(s_decay[tid] - s_decay[j]);
|
||||
A(tid, j) = -s_beta[tid] * dot_kk * decay_factor;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
|
||||
// Forward substitution: T = (I + A)^{-1}, in-place
|
||||
for (uint i = 1; i < chunk_len; i++) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue