vulkan: fuse w/k_gated broadcasts in chunked inter kernel

Load both s_w and s_kg before the first barrier of each token iteration
instead of issuing a separate barrier after each load. This reduces the
per-token barrier count from 3 to 2, i.e. one barrier fewer per token,
eliminating 64 barriers per chunk.

GDN per-op: 6818 → 5205 µs (-23.6%). 16/16 tests pass.
This commit is contained in:
Progeny Alpha 2026-03-14 22:32:46 -04:00
parent e22c2b2c85
commit 530e5bb117
1 changed file with 6 additions and 9 deletions

View File

@@ -74,6 +74,12 @@ void main() {
for (uint t = 0; t < chunk_len; t++) {
s_w[col] = w_in[wu_base + t * S_V + col];
const float g_cumsum_t = gcum_in[gcum_base + t];
const float decay_factor = exp(g_total - g_cumsum_t);
const uint t_global = chunk_start + t;
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
s_kg[col] = k_in[k_off + col] * decay_factor;
barrier();
float ws = 0.0;
@@ -87,15 +93,6 @@ void main() {
float vnew = u_in[wu_base + t * S_V + col] - ws;
vnew_out[wu_base + t * S_V + col] = vnew;
// K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
float g_cumsum_t = gcum_in[gcum_base + t];
float decay_factor = exp(g_total - g_cumsum_t);
const uint t_global = chunk_start + t;
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
s_kg[col] = k_in[k_off + col] * decay_factor;
barrier();
[[unroll]] for (uint i = 0; i < S_V; i++) {
delta[i] += s_kg[i] * vnew;
}