cache q values into registers for KQ

This commit is contained in:
Ruben Ortlam 2026-02-07 17:15:55 +01:00
parent 3c2088121c
commit ca5ec63cfb
1 changed files with 11 additions and 5 deletions

View File

@ -184,11 +184,17 @@ void main() {
} }
} }
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { FLOAT_TYPEV4 Q_cache[rows_per_thread];
continue; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid];
} }
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
#if BLOCK_SIZE > 1 #if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE; uint ib = coord / BLOCK_SIZE;
@ -198,7 +204,7 @@ void main() {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif #endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf)); Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
} }
} }
} }