vulkan: fuse w/k_gated broadcasts in chunked inter kernel
Load both s_w and s_kg before the first barrier instead of using separate barriers for each. Reduces per-token barriers from 3 to 2, eliminating 64 barriers per chunk. GDN per-op: 6818 → 5205 µs (-23.6%). 16/16 tests pass.
This commit is contained in:
parent
e22c2b2c85
commit
530e5bb117
|
|
@ -74,6 +74,12 @@ void main() {
|
|||
|
||||
for (uint t = 0; t < chunk_len; t++) {
|
||||
s_w[col] = w_in[wu_base + t * S_V + col];
|
||||
|
||||
const float g_cumsum_t = gcum_in[gcum_base + t];
|
||||
const float decay_factor = exp(g_total - g_cumsum_t);
|
||||
const uint t_global = chunk_start + t;
|
||||
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
|
||||
s_kg[col] = k_in[k_off + col] * decay_factor;
|
||||
barrier();
|
||||
|
||||
float ws = 0.0;
|
||||
|
|
@ -87,15 +93,6 @@ void main() {
|
|||
float vnew = u_in[wu_base + t * S_V + col] - ws;
|
||||
vnew_out[wu_base + t * S_V + col] = vnew;
|
||||
|
||||
// K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
|
||||
float g_cumsum_t = gcum_in[gcum_base + t];
|
||||
float decay_factor = exp(g_total - g_cumsum_t);
|
||||
|
||||
const uint t_global = chunk_start + t;
|
||||
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
|
||||
s_kg[col] = k_in[k_off + col] * decay_factor;
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
delta[i] += s_kg[i] * vnew;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue