Merge the inter-chunk state propagation and output computation into a single dispatch, reducing the chunked pipeline from 3 dispatches to 2.

State lives in registers across the sequential chunk loop. vnew is computed in-kernel and passed to the coopmat GEMM via shared memory (f16, packed with subgroup shuffles). This eliminates the VNew scratch buffer (wu_size) and the H_snapshots buffer (h_size), saving ~786KB/head/seq for PP-512.

Architecture per chunk:
  Step 1: Load K, Q, gcum → shared (all 256 threads)
  Step 2: Q@K^T coopmat → sh_attn (all 256 threads)
  Step 3: Decay mask + O_inter = Q@state → dst (parallel)
  Step 4: vnew = U - W@state → sh_kv (128 threads + k_gated assist)
  Step 5: O_intra = A_decayed @ vnew coopmat GEMM → dst
  Step 6: state = exp(decay) * state + delta

Shared memory: 63,744 / 65,536 bytes.
16/16 backend tests pass.
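A CPU reference sketch of the per-chunk steps, to show why fusing is safe. It assumes FLA-style chunked gated-delta-rule formulas that the commit does not spell out: W and U are precomputed by the earlier dispatch, gcum is the within-chunk cumulative log decay, O_inter is scaled by exp(gcum), and delta = k_gated^T @ vnew with k_gated = K * exp(g_last - gcum). All function names are hypothetical, not the shader's identifiers, and W/U are random here, so this checks fused-vs-unfused equivalence only, not the full recurrence.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Row-major matrix product (stand-in for the coopmat GEMMs)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def chunk_vnew(w: Matrix, u: Matrix, state: Matrix) -> Matrix:
    """Step 4: vnew = U - W @ state."""
    ws = matmul(w, state)
    return [[u[i][d] - ws[i][d] for d in range(len(u[0]))] for i in range(len(u))]


def chunk_output(q: Matrix, k: Matrix, g: list[float], state: Matrix, vnew: Matrix) -> Matrix:
    """Steps 2, 3, 5: O = exp(g_i) * (Q @ state) + A_decayed @ vnew (decay form assumed)."""
    n = len(q)
    a = matmul(q, transpose(k))  # Step 2: Q @ K^T
    # Step 3: causal mask, entry (i, j) scaled by exp(g_i - g_j) for j <= i
    a_dec = [[a[i][j] * math.exp(g[i] - g[j]) if j <= i else 0.0 for j in range(n)]
             for i in range(n)]
    qs = matmul(q, state)  # inter-chunk term
    o_intra = matmul(a_dec, vnew)  # Step 5
    return [[math.exp(g[i]) * qs[i][d] + o_intra[i][d] for d in range(len(vnew[0]))]
            for i in range(n)]


def update_state(state: Matrix, k: Matrix, g: list[float], vnew: Matrix) -> Matrix:
    """Step 6: state = exp(g_last) * state + k_gated^T @ vnew (assumed form of delta)."""
    g_last = g[-1]
    k_gated = [[k[i][e] * math.exp(g_last - g[i]) for e in range(len(k[0]))]
               for i in range(len(k))]
    delta = matmul(transpose(k_gated), vnew)
    return [[math.exp(g_last) * state[e][d] + delta[e][d] for d in range(len(state[0]))]
            for e in range(len(state))]


def fused(chunks, dk: int, dv: int) -> Matrix:
    """Single sequential loop: state stays in a local variable ('registers')."""
    state = [[0.0] * dv for _ in range(dk)]
    out: Matrix = []
    for q, k, w, u, g in chunks:
        vnew = chunk_vnew(w, u, state)
        out += chunk_output(q, k, g, state, vnew)
        state = update_state(state, k, g, vnew)
    return out


def two_pass(chunks, dk: int, dv: int) -> Matrix:
    """Unfused reference: pass 1 writes H snapshots and VNew, pass 2 reads them back."""
    state = [[0.0] * dv for _ in range(dk)]
    h_snapshots, vnews = [], []  # the scratch buffers the fused dispatch drops
    for q, k, w, u, g in chunks:
        vnew = chunk_vnew(w, u, state)
        h_snapshots.append(state)
        vnews.append(vnew)
        state = update_state(state, k, g, vnew)
    out: Matrix = []
    for (q, k, _w, _u, g), h, vnew in zip(chunks, h_snapshots, vnews):
        out += chunk_output(q, k, g, h, vnew)
    return out


def make_chunks(n_chunks: int, c: int, dk: int, dv: int, seed: int = 0):
    """Random (Q, K, W, U, gcum) per chunk; W and U are NOT derived from K here."""
    rng = random.Random(seed)

    def rand(rows: int, cols: int) -> Matrix:
        return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

    chunks = []
    for _ in range(n_chunks):
        g, acc = [], 0.0
        for _ in range(c):
            acc -= rng.uniform(0.0, 0.5)  # gcum: running sum of log-gates, <= 0
            g.append(acc)
        chunks.append((rand(c, dk), rand(c, dk), rand(c, dk), rand(c, dv), g))
    return chunks


chunks = make_chunks(n_chunks=3, c=4, dk=3, dv=2)
out = fused(chunks, dk=3, dv=2)
assert out == two_pass(chunks, dk=3, dv=2)
```

Both paths call the same per-chunk helpers in the same order, so the outputs match exactly; the only difference is that two_pass must keep h_snapshots and vnews alive between passes, which is the storage the fused dispatch removes.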
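One set of parameters that reproduces the ~786KB/head/seq figure for PP-512. The commit does not state them; head dimension 128 for both keys and values, chunk size 64, f32 storage, one H snapshot per chunk and one vnew row per token are assumptions, and scratch_bytes is a hypothetical helper.

```python
F32_BYTES = 4


def scratch_bytes(n_tokens: int, head_dim: int, chunk_size: int,
                  elem_bytes: int = F32_BYTES) -> tuple[int, int]:
    """Per-head, per-sequence sizes of the two scratch buffers (assumed layouts)."""
    n_chunks = -(-n_tokens // chunk_size)  # ceiling division
    h_size = n_chunks * head_dim * head_dim * elem_bytes  # one state snapshot per chunk
    vnew_size = n_tokens * head_dim * elem_bytes  # one vnew row per token
    return h_size, vnew_size


h_size, vnew_size = scratch_bytes(n_tokens=512, head_dim=128, chunk_size=64)
# 8 * 128 * 128 * 4 = 524,288 and 512 * 128 * 4 = 262,144; together 786,432 bytes
```

Under these assumptions the snapshots dominate, so the saving grows with the number of chunks per sequence.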
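vnew reaches the coopmat GEMM through shared memory as f16 packed with subgroup shuffles. A CPU emulation of one plausible scheme: adjacent lanes exchange values with an XOR shuffle (as GLSL's subgroupShuffleXor does), and each even lane writes one 32-bit word holding both halves in packHalf2x16 order (first value in the low 16 bits). The lane pairing and the helper names are assumptions; the shader may pair lanes differently.

```python
import struct


def f16_bits(x: float) -> int:
    """IEEE 754 binary16 bit pattern of x, via struct's 'e' (half-precision) format."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def f16_from_bits(bits: int) -> float:
    return struct.unpack("<e", struct.pack("<H", bits))[0]


def shuffle_xor(lanes: list[float], mask: int) -> list[float]:
    """Emulates subgroupShuffleXor: lane i reads the value held by lane i ^ mask."""
    return [lanes[i ^ mask] for i in range(len(lanes))]


def pack_pairs(lanes: list[float]) -> list[int]:
    """Each even lane packs (own value, neighbour's value) into one 32-bit word."""
    partner = shuffle_xor(lanes, 1)
    return [f16_bits(lanes[i]) | (f16_bits(partner[i]) << 16)
            for i in range(0, len(lanes), 2)]


def unpack_pairs(words: list[int]) -> list[float]:
    out: list[float] = []
    for w in words:
        out += [f16_from_bits(w & 0xFFFF), f16_from_bits(w >> 16)]
    return out


lanes = [0.5, -1.25, 3.0, 0.1]  # one f32 vnew value per lane
words = pack_pairs(lanes)  # half as many shared-memory words
restored = unpack_pairs(words)  # values rounded to f16
```

Storing f16 halves the shared-memory bytes per vnew element compared with f32, at the cost of rounding (0.1 above comes back as its nearest f16 value).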