stage V loads through shmem

This commit is contained in:
Ruben Ortlam 2026-02-08 10:27:10 +01:00
parent b7b67f8742
commit 8236c453a5
1 changed files with 35 additions and 5 deletions

View File

@ -189,6 +189,7 @@ void main() {
}
if (K_LOAD_SHMEM != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
@ -293,6 +294,30 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV / 4);
uint32_t c = (idx + tid) / (HSV / 4);
if (c < Bc) {
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
barrier();
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
@ -305,14 +330,19 @@ void main() {
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
FLOAT_TYPEV4 Vf;
if (K_LOAD_SHMEM != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
#else
FLOAT_TYPEV4 Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPEV4(Pf[r] * Vf);
}