Cache Q values into registers for the KQ dot products
This commit is contained in:
parent
3c2088121c
commit
ca5ec63cfb
|
|
@ -184,11 +184,17 @@ void main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
FLOAT_TYPEV4 Q_cache[rows_per_thread];
|
||||||
continue;
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
|
Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid];
|
||||||
}
|
}
|
||||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
|
||||||
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||||
|
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
#if BLOCK_SIZE > 1
|
#if BLOCK_SIZE > 1
|
||||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||||
uint ib = coord / BLOCK_SIZE;
|
uint ib = coord / BLOCK_SIZE;
|
||||||
|
|
@ -198,7 +204,7 @@ void main() {
|
||||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||||
#endif
|
#endif
|
||||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||||
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qfstride + d * D_split + d_tid], K_Tf));
|
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue